Merge remote-tracking branch 'upstream/master'

This commit is contained in:
avijit-nervana 2018-08-06 13:34:24 -07:00
commit 1149ad359f
541 changed files with 23312 additions and 9987 deletions

View File

@ -1,3 +1,68 @@
# Release 1.10.0
## Major Features And Improvements
* The `tf.lite` runtime now supports `complex64`.
* Initial Bigtable integration for `tf.data`.
* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation.
* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`.
* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018.
* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. See below for the complete list. New symbols have been added to the following modules: [`tf.debugging`](https://www.tensorflow.org/versions/master/api_docs/python/tf/debugging), [`tf.dtypes`](https://www.tensorflow.org/versions/master/api_docs/python/tf/dtypes), [`tf.image`](https://www.tensorflow.org/versions/master/api_docs/python/tf/image), [`tf.io`](https://www.tensorflow.org/versions/master/api_docs/python/tf/io), [`tf.linalg`](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), [`tf.manip`](https://www.tensorflow.org/versions/master/api_docs/python/tf/manip), [`tf.math`](https://www.tensorflow.org/versions/master/api_docs/python/tf/math), [`tf.quantization`](https://www.tensorflow.org/versions/master/api_docs/python/tf/quantization), [`tf.strings`](https://www.tensorflow.org/versions/master/api_docs/python/tf/strings)
## Breaking Changes
* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites).
* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake.
## Bug Fixes and Other Changes
* `tf.data`:
* `tf.contrib.data.group_by_reducer()` is now available via the public API.
* `tf.contrib.data.choose_from_datasets()` is now available via the public API.
* Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
* `tf.estimator`:
* `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export.
* `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases.
* Support sparse_combiner in canned Linear Estimators.
* Added batch normalization to `DNNClassifier`, `DNNRegressor`, and `DNNEstimator`.
* Adding ranking support for boosted trees.
* Adding center bias option for boosted trees.
* Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
* Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables.
* `tf.losses.*` do not add to the global collection when executing eagerly (to avoid leaking memory).
* Support different summary and checkpoint directories in `tf.train.MonitoredTrainingSession()`.
* Added IndRNN, IndyGRU, and IndyLSTM cells to `tf.contrib.rnn`.
* Add safe static factory functions for SparseTensor and convert all CHECKs to DCHECKs. Using the constructor directly is unsafe and deprecated.
* Make the Bigtable client connection pool configurable & increase the default # of connections for performance.
* Added derivative of `tf.random_gamma` with respect to the alpha parameter.
* Added derivative of `tf.igamma(a, x)` and `tf.igammac(a, x)` with respect to a.
* Modified Bessel functions of order zero and one.
* Add FillTriangular Bijector to create triangular matrices.
* Added support for Type III DCT, and `tf.spectral.idct(type=2|3)`.
* Correctly handle CuDNN RNN weight loaded when nest in `TimeDistributed`.
* Adding per-element weight support for `WALSComputePartialLhsAndRhsOp`.
* ZerosLike and OnesLike ops treated as constants by Graph Transform Tool.
* Gamma distribution and the derived distributions (Beta, Dirichlet, Student's t, inverse Gamma) now fully reparameterized.
* Java: Experimental wrapper classes to make graph generation easier. Thanks @karllessard and @kbsriram
* Build & link in secure gRPC components (switch from the insecure grpc dependency to secure grpc dependency).
* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. List of new endpoints:
* New endpoints in `tf.image` namespace: `tf.image.extract_image_patches`
* New endpoints in `tf.debugging` namespace: `tf.debugging.check_numerics`, `tf.debugging.is_finite`, `tf.debugging.is_inf`, `tf.debugging.is_nan`.
* New endpoints in `tf.dtypes` namespace: `tf.dtypes.as_string`.
* New endpoints in `tf.io` namespace: `tf.io.decode_base64`, `tf.io.decode_compressed`, `tf.io.decode_json_example`, `tf.io.decode_raw`, `tf.io.encode_base64`, `tf.io.matching_files`, `tf.io.parse_tensor`, `tf.io.read_file, `tf.io.write_file`.
* New endpoints in tf.linalg namespace: `tf.linalg.cross`, `tf.linalg.tensor_diag` (corresponds to `tf.diag`), `tf.linalg.tensor_diag_part` (corresponds to `tf.diag_part`).
* New endpoints in tf.manip namespace: `tf.manip.batch_to_space_nd`, `tf.manip.gather_nd`, `tf.manip.reshape`, `tf.manip.reverse`, `tf.manip.scatter_nd`, `tf.manip.space_to_batch_nd`, `tf.manip.tile`
* New endpoints in tf.math namespace: `tf.math.acos`, `tf.math.acosh`, `tf.math.add`, `tf.math.asin`, `tf.math.asinh`, `tf.math.atan`, `tf.math.atan2`, `tf.math.atanh`, `tf.math.betainc`, `tf.math.ceil`, `tf.math.cos`, `tf.math.cosh`, `tf.math.digamma`, `tf.math.equal`, `tf.math.erfc`, `tf.math.exp`, `tf.math.expm1`, `tf.math.floor`, `tf.math.greater`, `tf.math.greater_equal`, `tf.math.igamma`, `tf.math.igammac`, `tf.math.invert_permutation`, `tf.math.less`, `tf.math.less_equal`, `tf.math.lgamma`, `tf.math.log`, `tf.math.log1p`, `tf.math.logical_and`, `tf.math.logical_not`, `tf.math.logical_or`, `tf.math.maximum`, `tf.math.minimum`, `tf.math.not_equal`, `tf.math.polygamma`, `tf.math.reciprocal`, `tf.math.rint`, `tf.math.rsqrt`, `tf.math.segment_max`, `tf.math.segment_mean`, `tf.math.segment_min`, `tf.math.segment_prod`, `tf.math.segment_sum`, `tf.math.sin`, `tf.math.sinh`, `tf.math.softplus`, `tf.math.softsign`, `tf.math.squared_difference`, `tf.math.tan`, `tf.math.unsorted_segment_max`, `tf.math.unsorted_segment_min`, `tf.math.unsorted_segment_prod`, `tf.math.unsorted_segment_sum`, `tf.math.zeta`.
* New endpoints in `tf.quantization` namespace: `tf.quantization.dequantize`, `tf.quantization.fake_quant_with_min_max_args`, `tf.quantization.fake_quant_with_min_max_args_gradient`, `tf.quantization.fake_quant_with_min_max_vars`, `tf.quantization.fake_quant_with_min_max_vars_gradient`, `tf.quantization.fake_quant_with_min_max_vars_per_channel`, `tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient`.
* New endpoints in tf.strings namespace: `tf.strings.join` (corresponds to `tf.string_join`), `tf.strings.regex_replace`, `tf.strings.to_number` (corresponds to `tf.string_to_number`), `tf.strings.strip` (corresponds to `tf.string_strip`), `tf.strings.substr`, `tf.strings.to_hash_bucket` (corresponds to `tf.string_to_hash_bucket`), `tf.strings.to_hash_bucket_fast` (corresponds to `tf.string_to_hash_bucket_fast`), `tf.strings.to_hash_bucket_strong` (corresponds to `tf.string_to_hash_bucket_strong`).
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, Andrei Nigmatulin, Andrew Ginns, BjøRn Moholt, Brett Koonce, Chengzhi Chen, Chinmay Das, Christian Ertler, Christoph Boeddeker, Clayne Robison, Courtial Florian, ctiijima, Dan Douthit, Dan J, Dan Ringwalt, EFanZh, Emanuele Ballarin, eqy, Evgeniy Zheltonozhskiy, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, G K, gracehoney, Guillaume Klein, Guozhong Zhuang, Hsien-Yang Li, hsm207, ImSheridan, Jayaram Bobba, Jiandong Ruan, Jie, Joel Shor, Jonas Rauber, Jongmin Baek, jsawruk, Karan Kaw, Karl Lessard, karl@kubx.ca, Kb Sriram, KinmanLam, leiiwang, Li, Yiqiang, Loo Rong Jie, Mahmoud Abuzaina, Mahmoud Aslan, ManHyuk, Martin Patz, Martin Zeitler, mktozk, Mohammad Ashraf Bhuiyan, mrTsjolder, Naman Bhalla, Nick Felt, Nicolas Lopez, Niranjan Hasabnis, Nishidha Panpaliya, Nitish, nrstott, Nutti, Parag Jain, PeterLee, Philipp Jund, Rach L, Rafal Wojdyla, Roland Zimmermann, Sergei Lebedev, SneakyFish5, Soila Kavulya, Sriram Veturi, Steven Schmatz, Taehoon Lee, Tang, Wenyi, Taras Sereda, Ted Chang, Tim Zaman, Tristan Rice, tucan, vchigrin, Vikram Tiwari, Vincent, WeberXie, William D. Irons, Yan Facai (颜发才), Yong Tang, Yu Yi, Yuxin Wu, Zé ViníCius
# Release 1.9.0
## Major Features And Improvements

View File

@ -387,6 +387,7 @@ config_setting(
define_values = {
"dynamic_loaded_kernels": "true",
},
visibility = ["//visibility:public"],
)
config_setting(

View File

@ -150,8 +150,8 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::Status::OK();
}
tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TFE_Context** ctx) {
tensorflow::Status UpdateTFE_ContextWithServerDef(
const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@ -165,12 +165,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
} \
} while (0);
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
"/replica:0/task:", opts->server_def.task_index());
string worker_name =
tensorflow::strings::StrCat("/job:", server_def.job_name(),
"/replica:0/task:", server_def.task_index());
std::unique_ptr<tensorflow::ServerInterface> server;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server.get());
@ -202,15 +202,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, opts->server_def,
remote_eager_workers.get(), opts->async, &remote_contexts));
remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, opts->server_def, true));
session_name, server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@ -221,10 +221,10 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
*ctx = new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr, r, std::move(server),
std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts);
ctx->context.InitializeRemote(
std::move(server), std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts, r, device_mgr);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@ -249,15 +249,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status) {
if (!options->server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
}
}
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@ -267,12 +258,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (!opts->server_def.job_name().empty()) {
TFE_Context* ctx = nullptr;
status->status = NewRemoteAwareTFE_Context(opts, &ctx);
return ctx;
}
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@ -301,6 +286,20 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
const void* proto,
size_t proto_len,
TF_Status* status) {
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
@ -348,6 +347,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
int result;
status->status = h->handle->NumDims(&result);
return result;
@ -355,12 +359,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = nullptr;
status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
@ -368,6 +382,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;

View File

@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
// A tensorflow.ServerDef specifies remote workers (in addition to the current
// workers name). Operations created on this context can then be executed on
// any of these remote workers by setting an appropriate device.
//
// If the following is set, all servers identified by the
// ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status);
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
@ -127,6 +117,17 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
unsigned char async,
TF_Status* status);
// A tensorflow.ServerDef specifies remote workers (in addition to the current
// workers name). Operations created on this context can then be executed on
// any of these remote workers by setting an appropriate device.
//
// If the following is set, all servers identified by the
// ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
const void* proto,
size_t proto_len,
TF_Status* status);
// Causes the calling thread to block till all ops dispatched in async mode
// have been executed. Note that "execution" here refers to kernel execution /
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee

View File

@ -59,7 +59,6 @@ struct TFE_ContextOptions {
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
tensorflow::ServerDef server_def;
};
struct TFE_Context {
@ -73,23 +72,6 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
explicit TFE_Context(
const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_policy, bool async,
tensorflow::DeviceMgr* local_device_mgr,
tensorflow::Rendezvous* rendezvous,
std::unique_ptr<tensorflow::ServerInterface> server,
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
remote_contexts)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, local_device_mgr, rendezvous, std::move(server),
std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_contexts) {}
tensorflow::EagerContext context;
};

View File

@ -108,14 +108,14 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name("localhost");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name("localhost");
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < expected_values.size(); i++) {
EXPECT_EQ(expected_values[i], actual_values[i])
<< "Mismatch in expected values at (zero-based) index " << i;
}
}
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 =
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
TFE_DeleteTensorHandle(retval_task0);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
// Update the server def with a new set of names (worker instead of
// localhost).
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
serialized = updated_server_def.SerializeAsString();
updated_server_def.set_task_index(1);
tensorflow::Status s = tensorflow::GrpcServer::Create(
updated_server_def, tensorflow::Env::Default(), &worker_server);
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Copying and executing on the new remote device works.
const char new_remote_device_name[] =
"/job:worker/replica:0/task:1/device:CPU:0";
const char new_local_device_name[] =
"/job:worker/replica:0/task:0/device:CPU:0";
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
h0_task0_new, ctx, new_remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0_new);
TFE_DeleteTensorHandle(h0_task1_new);
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
new_local_device_name);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteChangeServerDef) {
TestRemoteExecuteChangeServerDef(false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
TestRemoteExecuteChangeServerDef(true);
}
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
@ -615,6 +760,42 @@ void SetAndGetOpDevices(bool async) {
TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandleNullptr) {
TFE_TensorHandle* h = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(t, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int dim = TFE_TensorHandleDim(h, 0, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(dim, -1);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
}
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -121,6 +121,7 @@ cc_library(
deps = [
":array_grad",
":data_flow_grad",
":image_grad",
":math_grad",
":nn_grad",
],
@ -331,6 +332,36 @@ tf_cc_test(
],
)
cc_library(
name = "image_grad",
srcs = ["gradients/image_grad.cc"],
deps = [
":cc_ops",
":cc_ops_internal",
":grad_op_registry",
":gradients",
],
alwayslink = 1,
)
tf_cc_test(
name = "gradients_image_grad_test",
srcs = ["gradients/image_grad_test.cc"],
deps = [
":cc_ops",
":client_session",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
":image_grad",
":testutil",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "math_grad",
srcs = ["gradients/math_grad.cc"],

View File

@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
Status ClientSession::MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) {
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
return impl()->session_->MakeCallable(callable_options, out_handle);
}
Status ClientSession::RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata) {
return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
run_metadata);
}
Status ClientSession::ReleaseCallable(CallableHandle handle) {
return impl()->session_->ReleaseCallable(handle);
}
} // end namespace tensorflow

View File

@ -87,7 +87,33 @@ class ClientSession {
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
// TODO(keveman): Add support for partial run.
/// \brief A handle to a subgraph, created with
/// `ClientSession::MakeCallable()`.
typedef int64 CallableHandle;
/// \brief Creates a `handle` for invoking the subgraph defined by
/// `callable_options`.
/// NOTE: This API is still experimental and may change.
Status MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle);
/// \brief Invokes the subgraph named by `handle` with the given options and
/// input tensors.
///
/// The order of tensors in `feed_tensors` must match the order of names in
/// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
/// match the order of names in `CallableOptions::fetch()` when this subgraph
/// was created.
/// NOTE: This API is still experimental and may change.
Status RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata);
/// \brief Releases resources associated with the given `handle` in this
/// session.
/// NOTE: This API is still experimental and may change.
Status ReleaseCallable(CallableHandle handle);
private:
class Impl;

View File

@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
TEST(ClientSessionTest, Callable) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);
auto b = Placeholder(root, DT_INT32);
auto c = Add(root, a, b);
ClientSession session(root);
std::vector<Tensor> outputs;
CallableOptions options;
options.add_feed(a.node()->name());
options.add_feed(b.node()->name());
options.add_fetch(c.node()->name());
ClientSession::CallableHandle callable;
TF_CHECK_OK(session.MakeCallable(options, &callable));
TF_EXPECT_OK(session.RunCallable(
callable, {test::AsTensor<int>({1}, {}), test::AsTensor<int>({41}, {})},
&outputs, nullptr));
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {}));
TF_EXPECT_OK(session.ReleaseCallable(callable));
}
} // namespace
} // namespace tensorflow

View File

@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
const int64 y_size = y_shapes[y_idx].num_elements();
const Y_T scale = Y_T{2 * delta};
const Y_T scale = 2 * delta;
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
for (int c = 0; c < y_size; ++c) {
SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
auto jac_n = jacobian_ns[i].matrix<JAC_T>();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
*max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c));
// Treat any NaN as max_error and immediately return.
// (Note that std::max may ignore NaN arguments.)
if (std::isnan(cur_error)) {
*max_error = cur_error;
return Status::OK();
}
*max_error = std::max(*max_error, cur_error);
}
}
}
@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x,
const Output& y, const TensorShape& y_shape, JAC_T* max_error);
INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
INSTANTIATE_GRAD_ERR_TYPE(double, float, double);
INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);

View File

@ -28,12 +28,14 @@ namespace {
using ops::Complex;
using ops::Const;
using ops::Div;
using ops::MatMul;
using ops::Placeholder;
using ops::Real;
using ops::Split;
using ops::Square;
using ops::Stack;
using ops::Sub;
using ops::Unstack;
TEST(GradientCheckerTest, BasicFloat) {
@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) {
EXPECT_LT(max_error, 1e-4);
}
// When calculating gradients that are undefined, test we get NaN
// as the computed error rather than 0.
TEST(GradientCheckerTest, BasicNan) {
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
// y = x/(x-x) should always return NaN
auto y = Div(scope, x, Sub(scope, x, x));
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_TRUE(std::isnan(max_error));
}
TEST(GradientCheckerTest, MatMulGrad) {
Scope scope = Scope::NewRootScope();

View File

@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/image_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
namespace tensorflow {
namespace ops {
namespace {
Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
bool align_corners;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners));
// The internal gradient implementation needs the shape of the input image.
// x_shape = shape(x)[1:3]
// = slice(shape(x), {1}, {3 - 1})
auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2});
grad_outputs->push_back(internal::ResizeNearestNeighborGrad(
scope, grad_inputs[0], x_shape,
internal::ResizeNearestNeighborGrad::AlignCorners(align_corners)));
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("ResizeNearestNeighbor", ResizeNearestNeighborGradHelper);
Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
bool align_corners;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners));
grad_outputs->push_back(internal::ResizeBilinearGrad(
scope, grad_inputs[0], op.input(0),
internal::ResizeBilinearGrad::AlignCorners(align_corners)));
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("ResizeBilinear", ResizeBilinearGradHelper);
Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
bool align_corners;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners));
grad_outputs->push_back(internal::ResizeBicubicGrad(
scope, grad_inputs[0], op.input(0),
internal::ResizeBicubicGrad::AlignCorners(align_corners)));
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,157 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
using ops::Const;
using ops::ResizeBicubic;
using ops::ResizeBilinear;
using ops::ResizeNearestNeighbor;
class ImageGradTest : public ::testing::Test {
protected:
ImageGradTest() : scope_(Scope::NewRootScope()) {}
enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC };
template <typename T>
Tensor MakeData(const TensorShape& data_shape) {
DataType data_type = DataTypeToEnum<T>::v();
Tensor data(data_type, data_shape);
auto data_flat = data.flat<T>();
for (int i = 0; i < data_flat.size(); ++i) {
data_flat(i) = T(i);
}
return data;
}
template <typename T>
void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape,
const bool align_corners, Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
switch (op_type) {
case RESIZE_NEAREST:
*y = ResizeNearestNeighbor(
scope_, *x, y_shape,
ResizeNearestNeighbor::AlignCorners(align_corners));
return;
case RESIZE_BILINEAR:
*y = ResizeBilinear(scope_, *x, y_shape,
ResizeBilinear::AlignCorners(align_corners));
return;
case RESIZE_BICUBIC:
*y = ResizeBicubic(scope_, *x, y_shape,
ResizeBicubic::AlignCorners(align_corners));
return;
}
assert(false);
}
template <typename T>
void TestResizedShapeForType(const OpType op_type, const bool align_corners) {
TensorShape x_shape({1, 2, 2, 1});
Tensor x_data = MakeData<T>(x_shape);
Output x, y;
MakeOp<T>(op_type, x_data, {4, 6}, align_corners, &x, &y);
ClientSession session(scope_);
std::vector<Tensor> outputs;
TF_ASSERT_OK(session.Run({y}, &outputs));
EXPECT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1}));
}
void TestResizedShape(OpType op_type) {
for (const bool align_corners : {true, false}) {
TestResizedShapeForType<Eigen::half>(op_type, align_corners);
TestResizedShapeForType<float>(op_type, align_corners);
TestResizedShapeForType<double>(op_type, align_corners);
}
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestResizeToSmallerAndAlign(const OpType op_type,
const bool align_corners) {
TensorShape x_shape({1, 4, 6, 1});
Tensor x_data = MakeData<X_T>(x_shape);
Output x, y;
MakeOp<X_T>(op_type, x_data, {2, 3}, align_corners, &x, &y);
JAC_T max_error;
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
scope_, x, x_data, y, {1, 2, 3, 1}, &max_error)));
EXPECT_LT(max_error, 1e-3);
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestResizeToLargerAndAlign(const OpType op_type,
const bool align_corners) {
TensorShape x_shape({1, 2, 3, 1});
Tensor x_data = MakeData<X_T>(x_shape);
Output x, y;
MakeOp<X_T>(op_type, x_data, {4, 6}, align_corners, &x, &y);
JAC_T max_error;
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
scope_, x, x_data, y, {1, 4, 6, 1}, &max_error)));
EXPECT_LT(max_error, 1e-3);
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestResize(OpType op_type) {
for (const bool align_corners : {true, false}) {
TestResizeToSmallerAndAlign<X_T, Y_T, JAC_T>(op_type, align_corners);
TestResizeToLargerAndAlign<X_T, Y_T, JAC_T>(op_type, align_corners);
}
}
Scope scope_;
};
TEST_F(ImageGradTest, TestNearestNeighbor) {
TestResizedShape(RESIZE_NEAREST);
TestResize<float, float, float>(RESIZE_NEAREST);
TestResize<double, double, double>(RESIZE_NEAREST);
}
TEST_F(ImageGradTest, TestBilinear) {
TestResizedShape(RESIZE_BILINEAR);
TestResize<float, float, float>(RESIZE_BILINEAR);
// Note that Y_T is always float for this op. We choose
// double for the jacobian to capture the higher precision
// between X_T and Y_T.
TestResize<double, float, double>(RESIZE_BILINEAR);
}
TEST_F(ImageGradTest, TestBicubic) {
TestResizedShape(RESIZE_BICUBIC);
TestResize<float, float, float>(RESIZE_BICUBIC);
// Note that Y_T is always float for this op. We choose
// double for the jacobian to capture the higher precision
// between X_T and Y_T.
TestResize<double, float, double>(RESIZE_BICUBIC);
}
} // namespace
} // namespace tensorflow

View File

@ -8,28 +8,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
# Optional runtime utilities for use by code generated by tfcompile.
cc_library(
name = "runtime",
srcs = ["runtime.cc"],
hdrs = ["runtime.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
],
)
tf_cc_test(
name = "runtime_test",
srcs = ["runtime_test.cc"],
deps = [
":runtime",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Don't depend on this directly; this is only used for the benchmark test
# generated by tf_library.
cc_library(
@ -53,9 +31,9 @@ cc_library(
],
deps = [
":embedded_protocol_buffers",
":runtime", # needed by codegen to print aligned_buffer_bytes
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -238,7 +216,6 @@ test_suite(
tests = [
":benchmark_test",
":codegen_test",
":runtime_test",
":test_graph_tfadd_test",
":test_graph_tfunknownop2_test",
":test_graph_tfunknownop3_test",

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -303,10 +303,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
const size_t arg_bytes_aligned =
runtime::aligned_buffer_bytes(iarg.data(), iarg.size());
cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size());
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
const size_t temp_bytes_aligned =
runtime::aligned_buffer_bytes(itemp.data(), itemp.size());
cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size());
const size_t temp_bytes_total =
total_buffer_bytes(itemp.data(), itemp.size());

View File

@ -16,339 +16,365 @@ tf_library(
)
"""
load("//tensorflow:tensorflow.bzl",
"if_android", "tf_cc_test", "tf_copts")
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_test",
"tf_copts",
)
def tf_library(name, graph, config,
freeze_checkpoint=None, freeze_saver=None,
cpp_class=None, gen_test=True, gen_benchmark=True,
visibility=None, testonly=None,
tfcompile_flags=None,
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps=True,
enable_xla_hlo_profiling=False, deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
def tf_library(
name,
graph,
config,
freeze_checkpoint = None,
freeze_saver = None,
cpp_class = None,
gen_test = True,
gen_benchmark = True,
visibility = None,
testonly = None,
tfcompile_flags = None,
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps = True,
enable_xla_hlo_profiling = False,
deps = None,
tags = None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
Given an invocation of tf_library(name="foo", ...), generates the following
build targets:
foo: A cc_library containing the generated header and computation.
foo_test: A cc_test with simple tests and benchmarks. Only created if
gen_test=True.
foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful
for mobile devices or other platforms that can't compile the
full test libraries. Only created if gen_benchmark=True.
Given an invocation of tf_library(name="foo", ...), generates the following
build targets:
foo: A cc_library containing the generated header and
computation.
foo_test: A cc_test with simple tests and benchmarks. Only created if
gen_test=True.
foo_benchmark: A cc_binary that runs a minimal-dependency benchmark,
useful for mobile devices or other platforms that can't
compile the full test libraries. Only created if
gen_benchmark=True.
The output header is called <name>.h.
Args:
name: The name of the build rule.
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
is expected to be in the human-readable proto text format, otherwise it is
expected to be in the proto binary format.
config: File containing tensorflow.tf2xla.Config proto. If the file ends
in '.pbtxt' it is expected to be in the human-readable proto text format,
otherwise it is expected to be in the proto binary format.
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
convert variables into constants.
freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
binary form, to convert variables into constants.
cpp_class: The name of the generated C++ class, wrapping the generated
function. The syntax of this flag is
[[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
for referring to a class, where multiple namespaces may precede the class
name, separated by double-colons. The class will be generated in the
given namespace(s), or if no namespaces are given, within the global
namespace.
gen_test: If True, also generate a cc_test rule that builds a simple
test and benchmark.
gen_benchmark: If True, also generate a binary with a simple benchmark.
Unlike the output of gen_test, this benchmark can be run on android.
visibility: Bazel build visibility.
testonly: Bazel testonly attribute.
tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
tfcompile_tool: The tfcompile binary. A non-default can be passed to
use a tfcompile built with extra dependencies.
include_standard_runtime_deps: If True, the standard list of kernel/runtime
deps is added to deps. If False, deps must contain the full set of deps
needed by the generated library.
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
and emit metadata that lets us pretty-print the gathered profile counters.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.
Args:
name: The name of the build rule.
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt'
it is expected to be in the human-readable proto text format, otherwise
it is expected to be in the proto binary format.
config: File containing tensorflow.tf2xla.Config proto. If the file ends
in '.pbtxt' it is expected to be in the human-readable proto text
format, otherwise it is expected to be in the proto binary format.
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
convert variables into constants.
freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
binary form, to convert variables into constants.
cpp_class: The name of the generated C++ class, wrapping the generated
function. The syntax of this flag is
[[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
for referring to a class, where multiple namespaces may precede the
class name, separated by double-colons. The class will be generated in
the given namespace(s), or if no namespaces are given, within the global
namespace.
gen_test: If True, also generate a cc_test rule that builds a simple
test and benchmark.
gen_benchmark: If True, also generate a binary with a simple benchmark.
Unlike the output of gen_test, this benchmark can be run on android.
visibility: Bazel build visibility.
testonly: Bazel testonly attribute.
tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
tfcompile_tool: The tfcompile binary. A non-default can be passed to
use a tfcompile built with extra dependencies.
include_standard_runtime_deps: If True, the standard list of
kernel/runtime deps is added to deps. If False, deps must contain the
full set of deps needed by the generated library.
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
program, and emit metadata that lets us pretty-print the gathered
profile counters.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.
"""
if not cpp_class:
fail("cpp_class must be specified")
The output header is called <name>.h.
"""
if not cpp_class:
fail("cpp_class must be specified")
tfcompile_graph = graph
if freeze_checkpoint or freeze_saver:
if not freeze_checkpoint:
fail("freeze_checkpoint must be specified when freeze_saver is " +
"specified")
tfcompile_graph = graph
if freeze_checkpoint or freeze_saver:
if not freeze_checkpoint:
fail("freeze_checkpoint must be specified when freeze_saver is specified")
freeze_name = "freeze_" + name
freeze_file = freeze_name + ".pb"
freeze_name = "freeze_" + name
freeze_file = freeze_name + ".pb"
# First run tfcompile to generate the list of out_nodes.
out_nodes_file = "out_nodes_" + freeze_name
native.genrule(
name = ("gen_" + out_nodes_file),
srcs = [config],
outs = [out_nodes_file],
cmd = ("$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
tools = [tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local = 1,
tags = tags,
)
# First run tfcompile to generate the list of out_nodes.
out_nodes_file = "out_nodes_" + freeze_name
# Now run freeze_graph to convert variables into constants.
freeze_args = (
" --input_graph=$(location " + graph + ")" +
" --checkpoint_version=1" +
" --input_binary=" + str(not graph.endswith(".pbtxt")) +
" --input_checkpoint=$(location " + freeze_checkpoint + ")" +
" --output_graph=$(location " + freeze_file + ")" +
" --output_node_names=$$(<$(location " + out_nodes_file +
"))"
)
freeze_saver_srcs = []
if freeze_saver:
freeze_args += " --input_saver=$(location " + freeze_saver + ")"
freeze_saver_srcs += [freeze_saver]
native.genrule(
name = freeze_name,
srcs = [
graph,
freeze_checkpoint,
out_nodes_file,
] + freeze_saver_srcs,
outs = [freeze_file],
cmd = ("$(location " +
"//tensorflow/python/tools:freeze_graph)" +
freeze_args),
tools = ["//tensorflow/python/tools:freeze_graph"],
tags = tags,
)
tfcompile_graph = freeze_file
# Rule that runs tfcompile to produce the header and object file.
header_file = name + ".h"
metadata_object_file = name + "_tfcompile_metadata.o"
function_object_file = name + "_tfcompile_function.o"
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
if type(tfcompile_flags) == type(""):
flags = tfcompile_flags
else:
flags = " ".join([
"'" + arg.replace("'", "'\\''") + "'"
for arg in (tfcompile_flags or [])
])
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
else:
profiling_flag = ""
native.genrule(
name=("gen_" + out_nodes_file),
srcs=[config],
outs=[out_nodes_file],
cmd=("$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
tools=[tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local=1,
tags=tags,
)
# Now run freeze_graph to convert variables into constants.
freeze_args = (" --input_graph=$(location " + graph + ")" +
" --checkpoint_version=1" +
" --input_binary=" + str(not graph.endswith(".pbtxt")) +
" --input_checkpoint=$(location " + freeze_checkpoint + ")" +
" --output_graph=$(location " + freeze_file + ")" +
" --output_node_names=$$(<$(location " + out_nodes_file +
"))")
freeze_saver_srcs = []
if freeze_saver:
freeze_args += " --input_saver=$(location " + freeze_saver + ")"
freeze_saver_srcs += [freeze_saver]
native.genrule(
name=freeze_name,
srcs=[
graph,
freeze_checkpoint,
out_nodes_file,
] + freeze_saver_srcs,
outs=[freeze_file],
cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
freeze_args),
tools=["//tensorflow/python/tools:freeze_graph"],
tags=tags,
)
tfcompile_graph = freeze_file
# Rule that runs tfcompile to produce the header and object file.
header_file = name + ".h"
metadata_object_file = name + "_tfcompile_metadata.o"
function_object_file = name + "_tfcompile_function.o"
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
if type(tfcompile_flags) == type(""):
flags = tfcompile_flags
else:
flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
else:
profiling_flag = ""
native.genrule(
name=("gen_" + name),
srcs=[
tfcompile_graph,
config,
],
outs=[
header_file,
metadata_object_file,
function_object_file,
],
cmd=("$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_header=$(@D)/" + header_file +
" --out_metadata_object=$(@D)/" + metadata_object_file +
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
# Run tfcompile on the build host since it's typically faster on the local
# machine.
#
# Note that setting the local=1 attribute on a *test target* causes the
# test infrastructure to skip that test. However this is a genrule, not a
# test target, and runs with --genrule_strategy=forced_forge, meaning the
# local=1 attribute is ignored, and the genrule is still run.
#
# https://www.bazel.io/versions/master/docs/be/general.html#genrule
local=1,
tags=tags,
)
# Rule that runs tfcompile to produce the SessionModule proto, useful for
# debugging. TODO(b/64813587): Once the SessionModule proto is
# deterministic, move this into the main rule above.
session_module_pb = name + "_session_module.pb"
native.genrule(
name=(name + "_session_module"),
srcs=[
tfcompile_graph,
config,
],
outs=[
session_module_pb,
],
cmd=("$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_session_module=$(@D)/" + session_module_pb +
" " + flags),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
local=1,
tags=tags,
)
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
native.cc_library(
name=name,
srcs=[function_object_file, metadata_object_file],
hdrs=[header_file],
visibility=visibility,
testonly=testonly,
deps = [
# These deps are required by all tf_library targets even if
# include_standard_runtime_deps is False. Without them, the
# generated code will fail to compile.
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
"//tensorflow/core:framework_lite",
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the proto.
"//tensorflow/compiler/xla:xla_data_proto",
] or []) + (enable_xla_hlo_profiling and [
"//tensorflow/compiler/xla/service:hlo_profile_printer_data"
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//third_party/eigen3",
] or []) + (deps or []),
tags=tags,
)
# Variables used for gen_test and gen_benchmark.
no_ns_name = ""
cpp_class_split = cpp_class.rsplit("::", maxsplit=2)
if len(cpp_class_split) == 1:
no_ns_name = cpp_class_split[0]
else:
no_ns_name = cpp_class_split[1]
sed_replace = (
"-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
"-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
"-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")
if gen_test:
test_name = name + "_test"
test_file = test_name + ".cc"
# Rule to rewrite test.cc to produce the test_file.
native.genrule(
name=("gen_" + test_name),
testonly=1,
srcs=[
"//tensorflow/compiler/aot:test.cc",
header_file,
name = ("gen_" + name),
srcs = [
tfcompile_graph,
config,
],
outs=[test_file],
cmd=("sed " + sed_replace +
" $(location //tensorflow/compiler/aot:test.cc) " +
"> $(OUTS)"),
tags=tags,
outs = [
header_file,
metadata_object_file,
function_object_file,
],
cmd = ("$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_header=$(@D)/" + header_file +
" --out_metadata_object=$(@D)/" + metadata_object_file +
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag),
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
# Run tfcompile on the build host since it's typically faster on the
# local machine.
#
# Note that setting the local=1 attribute on a *test target* causes the
# test infrastructure to skip that test. However this is a genrule, not
# a test target, and runs with --genrule_strategy=forced_forge, meaning
# the local=1 attribute is ignored, and the genrule is still run.
#
# https://www.bazel.io/versions/master/docs/be/general.html#genrule
local = 1,
tags = tags,
)
# The cc_test rule for the generated code. To ensure that this works
# reliably across build configurations, we must use tf_cc_test instead of
# native.cc_test. This is related to how we build
# //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
# for more details.
tf_cc_test(
name=test_name,
srcs=[test_file],
deps=[
":" + name,
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/aot:tf_library_test_main",
"//tensorflow/compiler/xla:executable_run_options",
# Rule that runs tfcompile to produce the SessionModule proto, useful for
# debugging. TODO(b/64813587): Once the SessionModule proto is
# deterministic, move this into the main rule above.
session_module_pb = name + "_session_module.pb"
native.genrule(
name = (name + "_session_module"),
srcs = [
tfcompile_graph,
config,
],
outs = [
session_module_pb,
],
cmd = ("$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_session_module=$(@D)/" + session_module_pb +
" " + flags),
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
local = 1,
tags = tags,
)
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
native.cc_library(
name = name,
srcs = [function_object_file, metadata_object_file],
hdrs = [header_file],
visibility = visibility,
testonly = testonly,
deps = [
# These deps are required by all tf_library targets even if
# include_standard_runtime_deps is False. Without them, the
# generated code will fail to compile.
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
"//tensorflow/core:framework_lite",
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the
# proto.
"//tensorflow/compiler/xla:xla_data_proto",
] or []) + (enable_xla_hlo_profiling and [
"//tensorflow/compiler/xla/service:hlo_profile_printer_data",
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually
# needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//third_party/eigen3",
"//tensorflow/core:lib",
"//tensorflow/core:test",
] or []) + (deps or []),
tags = tags,
)
# Variables used for gen_test and gen_benchmark.
cpp_class_split = cpp_class.rsplit("::", maxsplit = 2)
if len(cpp_class_split) == 1:
no_ns_name = cpp_class_split[0]
else:
no_ns_name = cpp_class_split[1]
sed_replace = (
"-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
"-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
"-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" "
)
if gen_test:
test_name = name + "_test"
test_file = test_name + ".cc"
# Rule to rewrite test.cc to produce the test_file.
native.genrule(
name = ("gen_" + test_name),
testonly = 1,
srcs = [
"//tensorflow/compiler/aot:test.cc",
header_file,
],
tags=tags,
)
outs = [test_file],
cmd = (
"sed " + sed_replace +
" $(location //tensorflow/compiler/aot:test.cc) " +
"> $(OUTS)"
),
tags = tags,
)
if gen_benchmark:
benchmark_name = name + "_benchmark"
benchmark_file = benchmark_name + ".cc"
benchmark_main = ("//tensorflow/compiler/aot:" +
"benchmark_main.template")
# The cc_test rule for the generated code. To ensure that this works
# reliably across build configurations, we must use tf_cc_test instead
# of native.cc_test. This is related to how we build
# //tensorflow/core:lib -- see the note in
# tensorflow/core/BUILD for more details.
tf_cc_test(
name = test_name,
srcs = [test_file],
deps = [
":" + name,
"//tensorflow/compiler/aot:tf_library_test_main",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
tags = tags,
)
# Rule to rewrite benchmark.cc to produce the benchmark_file.
native.genrule(
name=("gen_" + benchmark_name),
srcs=[
benchmark_main,
header_file,
],
testonly = testonly,
outs=[benchmark_file],
cmd=("sed " + sed_replace +
" $(location " + benchmark_main + ") " +
"> $(OUTS)"),
tags=tags,
)
if gen_benchmark:
benchmark_name = name + "_benchmark"
benchmark_file = benchmark_name + ".cc"
benchmark_main = ("//tensorflow/compiler/aot:" +
"benchmark_main.template")
# The cc_benchmark rule for the generated code. This does not need the
# tf_cc_binary since we (by deliberate design) do not depend on
# //tensorflow/core:lib.
#
# Note: to get smaller size on android for comparison, compile with:
# --copt=-fvisibility=hidden
# --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
# --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
native.cc_binary(
name=benchmark_name,
srcs=[benchmark_file],
testonly = testonly,
copts = tf_copts(),
linkopts = if_android(["-pie", "-s"]),
deps=[
":" + name,
"//tensorflow/compiler/aot:benchmark",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
] + if_android([
"//tensorflow/compiler/aot:benchmark_extra_android",
]),
tags=tags,
)
# Rule to rewrite benchmark.cc to produce the benchmark_file.
native.genrule(
name = ("gen_" + benchmark_name),
srcs = [
benchmark_main,
header_file,
],
testonly = testonly,
outs = [benchmark_file],
cmd = ("sed " + sed_replace +
" $(location " + benchmark_main + ") " +
"> $(OUTS)"),
tags = tags,
)
# The cc_benchmark rule for the generated code. This does not need the
# tf_cc_binary since we (by deliberate design) do not depend on
# //tensorflow/core:lib.
#
# Note: to get smaller size on android for comparison, compile with:
# --copt=-fvisibility=hidden
# --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
# --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
native.cc_binary(
name = benchmark_name,
srcs = [benchmark_file],
testonly = testonly,
copts = tf_copts(),
linkopts = if_android(["-pie", "-s"]),
deps = [
":" + name,
"//tensorflow/compiler/aot:benchmark",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
] + if_android([
"//tensorflow/compiler/aot:benchmark_extra_android",
]),
tags = tags,
)
def target_llvm_triple():
"""Returns the target LLVM triple to be used for compiling the target."""
# TODO(toddw): Add target_triple for other targets. For details see:
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
return select({
"//tensorflow:android_armeabi": "armv5-none-android",
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:darwin": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",
})
"""Returns the target LLVM triple to be used for compiling the target."""
# TODO(toddw): Add target_triple for other targets. For details see:
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
return select({
"//tensorflow:android_armeabi": "armv5-none-android",
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:darwin": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",
})

View File

@ -306,6 +306,7 @@ cc_library(
srcs = [
"build_xla_launch_ops_pass.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
],
@ -377,11 +378,39 @@ tf_cc_test(
],
)
tf_cc_test(
name = "deadness_analysis_test",
size = "small",
srcs = [
"deadness_analysis_internal.h",
"deadness_analysis_test.cc",
],
deps = [
":common",
":compilation_passes",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
"deadness_analysis_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
],

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@ -151,7 +152,11 @@ class SymbolPredicate : public Predicate {
tensor_id_(std::move(tensor_id)),
must_be_true_(must_be_true) {}
string ToString() const override { return tensor_id_.ToString(); }
string ToString() const override {
return must_be_true() ? strings::StrCat("*", tensor_id_.ToString())
: tensor_id_.ToString();
}
Kind kind() const override { return Kind::kSymbol; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
@ -348,6 +353,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status Populate();
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
@ -563,4 +569,24 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
return Status::OK();
}
gtl::FlatMap<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
gtl::FlatMap<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
}
return result;
}
namespace deadness_analysis_internal {
Status ComputePredicates(const Graph& graph,
PredicateMapTy* out_predicate_map) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.Populate());
*out_predicate_map = impl.PredicateMapAsString();
return Status::OK();
}
} // namespace deadness_analysis_internal
} // namespace tensorflow

View File

@ -0,0 +1,32 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only.
using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -439,5 +440,28 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) {
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
}
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
// Demonstrates why we need the must_be_true bit on SymbolP.
Scope root = Scope::NewRootScope().ExitOnError();
Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
0, "receiver");
Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
ops::Switch sw(root.WithOpName("switch"), value, recv);
Output logical_and =
ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
deadness_analysis_internal::PredicateMapTy predicate_map;
TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(),
&predicate_map));
TensorId logical_and_output_0 = {logical_and.node()->name(),
Graph::kControlSlot};
EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)");
}
} // namespace
} // namespace tensorflow

View File

@ -462,6 +462,7 @@ Status MarkForCompilationPass::Run(
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
const FunctionLibraryDefinition* fld = options.flib_def;
std::unique_ptr<DeadnessAnalysis> deadness;

View File

@ -258,6 +258,7 @@ Status XlaCompilationCache::CompileImpl(
xla::LocalExecutable** executable,
const XlaCompiler::CompileOptions* compile_options,
bool compile_single_op) {
CHECK_NE(executable, nullptr);
VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) {
@ -329,18 +330,15 @@ Status XlaCompilationCache::CompileImpl(
compile_options ? *compile_options : XlaCompiler::CompileOptions(),
function, args, &entry->compilation_result);
}
TF_RETURN_IF_ERROR(entry->compilation_status);
CHECK_EQ(entry->executable.get(), nullptr);
entry->compilation_status =
BuildExecutable(options, entry->compilation_result, &entry->executable);
}
TF_RETURN_IF_ERROR(entry->compilation_status);
*compilation_result = &entry->compilation_result;
if (entry->compilation_status.ok() && executable) {
if (entry->executable == nullptr) {
entry->compilation_status = BuildExecutable(
options, entry->compilation_result, &entry->executable);
}
*executable = entry->executable.get();
}
Status status = entry->compilation_status;
return status;
*executable = entry->executable.get();
return Status::OK();
}
} // namespace tensorflow

View File

@ -400,6 +400,21 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertEqual(75, y.numpy())
self.assertEqual(30, dy.numpy())
def testGradientTapeInDefun(self):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
@function.defun
def f():
x = constant_op.constant(1.0)
with backprop.GradientTape() as tape:
y = v0 * x
dy = tape.gradient(y, v0)
return dy
dy = f()
self.assertEqual(1.0, dy.numpy())
def testSliceInDefun(self):
with self.test_scope():

View File

@ -62,6 +62,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@ -101,6 +102,9 @@ class OpTestBuilder {
OpTestBuilder& RandomInput(DataType type);
OpTestBuilder& RandomInput(DataType type, std::vector<int64> dims);
// As RandomInput but the values are unique.
OpTestBuilder& RandomUniqueInput(DataType type, std::vector<int64> dims);
// Sets an attribute.
template <class T>
OpTestBuilder& Attr(StringPiece attr_name, T&& value);
@ -126,6 +130,7 @@ class OpTestBuilder {
DataType type = DT_INVALID;
bool has_dims = false;
bool needs_unique_values = false;
std::vector<int64> dims;
};
@ -167,6 +172,18 @@ OpTestBuilder& OpTestBuilder::RandomInput(DataType type,
return *this;
}
OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
std::vector<int64> dims) {
VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
InputDescription input;
input.type = type;
input.has_dims = true;
input.needs_unique_values = true;
input.dims = std::move(dims);
inputs_.push_back(input);
return *this;
}
template <class T>
OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
@ -289,7 +306,8 @@ class OpTest : public ::testing::Test {
// Returns a tensor filled with random but "reasonable" values from the middle
// of the type's range. If the shape is omitted, a random shape is used.
// TODO(phawkins): generalize this code to a caller-supplied distribution.
Tensor RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape);
Tensor RandomTensor(DataType dtype, bool needs_unique_values,
gtl::ArraySlice<int64> shape);
Tensor RandomTensor(DataType dtype);
// Like RandomTensor, but uses values >= 0.
@ -432,49 +450,90 @@ std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
return dims;
}
Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
gtl::ArraySlice<int64> shape) {
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
gtl::FlatSet<float> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<float>(&tensor, [this, &distribution](int i) -> float {
return distribution(generator());
test::FillFn<float>(&tensor, [&](int i) -> float {
float generated;
do {
generated = distribution(generator());
} while (needs_unique_values &&
!already_generated.insert(generated).second);
return generated;
});
break;
}
case DT_DOUBLE: {
gtl::FlatSet<double> already_generated;
std::uniform_real_distribution<double> distribution(-1.0, 1.0);
test::FillFn<double>(&tensor, [this, &distribution](int i) -> double {
return distribution(generator());
test::FillFn<double>(&tensor, [&](int i) -> double {
double generated;
do {
generated = distribution(generator());
} while (needs_unique_values &&
!already_generated.insert(generated).second);
return generated;
});
break;
}
case DT_COMPLEX64: {
gtl::FlatSet<std::pair<float, float>> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<complex64>(&tensor, [this, &distribution](int i) {
return complex64(distribution(generator()), distribution(generator()));
test::FillFn<complex64>(&tensor, [&](int i) {
complex64 generated;
do {
generated =
complex64(distribution(generator()), distribution(generator()));
} while (
needs_unique_values &&
!already_generated
.insert(std::make_pair(generated.real(), generated.imag()))
.second);
return generated;
});
break;
}
case DT_INT32: {
gtl::FlatSet<int32> already_generated;
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
return distribution(generator());
test::FillFn<int32>(&tensor, [&](int i) -> int32 {
int32 generated;
do {
generated = distribution(generator());
} while (needs_unique_values &&
!already_generated.insert(generated).second);
return generated;
});
break;
}
case DT_INT64: {
gtl::FlatSet<int64> already_generated;
std::uniform_int_distribution<int64> distribution(-(1LL << 40),
1LL << 40);
test::FillFn<int64>(&tensor, [this, &distribution](int i) -> int64 {
return distribution(generator());
test::FillFn<int64>(&tensor, [&](int i) -> int64 {
int64 generated;
do {
generated = distribution(generator());
} while (needs_unique_values &&
!already_generated.insert(generated).second);
return generated;
});
break;
}
case DT_BOOL: {
gtl::FlatSet<bool> already_generated;
std::bernoulli_distribution distribution;
test::FillFn<bool>(&tensor, [this, &distribution](int i) -> bool {
return distribution(generator());
test::FillFn<bool>(&tensor, [&](int i) -> bool {
bool generated;
do {
generated = distribution(generator());
} while (needs_unique_values &&
!already_generated.insert(generated).second);
return generated;
});
break;
}
@ -485,7 +544,7 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
}
Tensor OpTest::RandomTensor(DataType dtype) {
return RandomTensor(dtype, RandomDims());
return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims());
}
Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
@ -761,7 +820,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
VLOG(1) << "Ignoring oversize dims.";
return kInvalid;
}
input_tensors.push_back(RandomTensor(input.type, dims));
input_tensors.push_back(
RandomTensor(input.type, input.needs_unique_values, dims));
}
VLOG(1) << "Input: " << input_tensors.back().DebugString();
}
@ -960,7 +1020,7 @@ TEST_F(OpTest, ArgMax) {
std::uniform_int_distribution<int32>(-num_dims, num_dims)(generator());
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("ArgMax")
.RandomInput(DT_FLOAT, dims)
.RandomUniqueInput(DT_FLOAT, dims)
.Input(test::AsScalar<int32>(reduce_dim))
.Attr("T", DT_FLOAT)
.Attr("Tidx", DT_INT32)
@ -976,7 +1036,7 @@ TEST_F(OpTest, ArgMin) {
std::uniform_int_distribution<int32>(-num_dims, num_dims)(generator());
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("ArgMin")
.RandomInput(DT_FLOAT, dims)
.RandomUniqueInput(DT_FLOAT, dims)
.Input(test::AsScalar<int32>(reduce_dim))
.Attr("T", DT_FLOAT)
.Attr("Tidx", DT_INT32)

View File

@ -91,6 +91,18 @@ cc_library(
],
)
cc_library(
name = "cpu_function_runtime",
srcs = ["cpu_function_runtime.cc"],
hdrs = ["cpu_function_runtime.h"],
deps = [
# Keep dependencies to a minimum here; this library is used in every AOT
# binary produced by tfcompile.
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
],
)
cc_library(
name = "xla_compiled_cpu_function",
srcs = ["xla_compiled_cpu_function.cc"],
@ -99,12 +111,23 @@ cc_library(
deps = [
# Keep dependencies to a minimum here; this library is used in every AOT
# binary produced by tfcompile.
"//tensorflow/compiler/aot:runtime",
":cpu_function_runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
],
)
tf_cc_test(
name = "cpu_function_runtime_test",
srcs = ["cpu_function_runtime_test.cc"],
deps = [
":cpu_function_runtime",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "xla_jit_compiled_cpu_function",
srcs = ["xla_jit_compiled_cpu_function.cc"],

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,22 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/runtime.h"
#include <stdlib.h>
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
namespace tensorflow {
namespace tfcompile {
namespace runtime {
namespace {
// Inline memory allocation routines here, because depending on '//base' brings
// in libraries which use c++ streams, which adds considerable code size on
// android.
inline void* aligned_malloc(size_t size, int minimum_alignment) {
void* aligned_malloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN)
return memalign(minimum_alignment, size);
#elif defined(_WIN32)
@ -47,7 +41,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) {
#endif
}
inline void aligned_free(void* aligned_memory) {
void aligned_free(void* aligned_memory) {
#if defined(_WIN32)
_aligned_free(aligned_memory);
#else
@ -58,13 +52,13 @@ inline void aligned_free(void* aligned_memory) {
size_t align_to(size_t n, size_t align) {
return (((n - 1) / align) + 1) * align;
}
} // namespace
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
namespace cpu_function_runtime {
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
if (sizes[i] != -1) {
if (sizes[i] > 0) {
total += align_to(sizes[i], kAlign);
}
}
@ -73,7 +67,7 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
bool annotate_initialized) {
const size_t total = aligned_buffer_bytes(sizes, n);
const size_t total = AlignedBufferBytes(sizes, n);
void* contiguous = nullptr;
if (total > 0) {
contiguous = aligned_malloc(total, kAlign);
@ -85,7 +79,9 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
if (sizes[i] == -1) {
if (sizes[i] < 0) {
// bufs[i] is either a constant, an entry parameter or a thread local
// allocation.
bufs[i] = nullptr;
} else {
bufs[i] = reinterpret_cast<void*>(pos);
@ -100,7 +96,5 @@ void FreeContiguous(void* contiguous) {
aligned_free(contiguous);
}
}
} // namespace runtime
} // namespace tfcompile
} // namespace cpu_function_runtime
} // namespace tensorflow

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,25 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains utilities to make it easier to invoke functions generated
// by tfcompile. Usage of these utilities is optional.
#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_
#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_
#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_
#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tfcompile {
namespace runtime {
namespace cpu_function_runtime {
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
static constexpr size_t kAlign = 64;
constexpr size_t kAlign = 64;
// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign
// byte boundaries.
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);
// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`. Each buffer is aligned to
// kAlign byte boundaries.
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n);
// MallocContiguousBuffers allocates buffers for use by the entry point
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
@ -41,8 +37,8 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);
// temporary buffers.
//
// A single contiguous block of memory is allocated, and portions of it are
// parceled out into `bufs`, which must have space for `n` entries. Returns the
// head of the allocated contiguous block, which should be passed to
// parceled out into `bufs`, which must have space for `n` entries. Returns
// the head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use.
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
bool annotate_initialized);
@ -50,9 +46,7 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
// FreeContiguous frees the contiguous block of memory allocated by
// MallocContiguousBuffers.
void FreeContiguous(void* contiguous);
} // namespace runtime
} // namespace tfcompile
} // namespace cpu_function_runtime
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_
#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_

View File

@ -13,39 +13,37 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/runtime.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace runtime {
namespace {
TEST(Runtime, AlignmentValue) {
TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
// The tfcompile runtime also has a requirement that comes from the xla
// generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8
// So any value that we choose must abide by that constraint as well.
EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment);
EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment);
}
TEST(Runtime, AlignedBufferBytes) {
EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0);
TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) {
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0);
static constexpr intptr_t sizesA[1] = {-1};
EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0);
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3};
EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64);
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64);
static constexpr intptr_t sizesC[1] = {32};
EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64);
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320);
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320);
}
void* add_ptr(void* base, uintptr_t delta) {
@ -56,48 +54,49 @@ void* add_ptr(void* base, uintptr_t delta) {
// expected nullptrs, and write to each byte of allocated memory. We rely on
// the leak checker to tell us if there's an inconsistency between malloc and
// free. We also check the contiguous property.
TEST(Runtime, MallocFreeContiguousBuffers) {
TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test empty sizes.
void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false);
void* base =
cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false);
EXPECT_EQ(base, nullptr);
FreeContiguous(base);
cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with 0 sum.
static constexpr intptr_t sizesA[1] = {-1};
void* bufA[1];
base = MallocContiguousBuffers(sizesA, 1, bufA, false);
base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false);
EXPECT_EQ(base, nullptr);
EXPECT_EQ(bufA[0], nullptr);
FreeContiguous(base);
cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with non-0 sum.
static constexpr intptr_t sizesB[1] = {3};
void* bufB[1];
base = MallocContiguousBuffers(sizesB, 1, bufB, false);
base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufB[0], add_ptr(base, 0));
char* bufB0_bytes = static_cast<char*>(bufB[0]);
bufB0_bytes[0] = 'A';
bufB0_bytes[1] = 'B';
bufB0_bytes[2] = 'C';
FreeContiguous(base);
cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with non-0 sum, and annotate_initialized.
static constexpr intptr_t sizesC[1] = {3};
void* bufC[1];
base = MallocContiguousBuffers(sizesC, 1, bufC, true);
base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufC[0], add_ptr(base, 0));
char* bufC0_bytes = static_cast<char*>(bufC[0]);
bufC0_bytes[0] = 'A';
bufC0_bytes[1] = 'B';
bufC0_bytes[2] = 'C';
FreeContiguous(base);
cpu_function_runtime::FreeContiguous(base);
// Test mixed sizes.
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
void* bufD[7];
base = MallocContiguousBuffers(sizesD, 7, bufD, false);
base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr);
@ -115,10 +114,8 @@ TEST(Runtime, MallocFreeContiguousBuffers) {
}
}
}
FreeContiguous(base);
cpu_function_runtime::FreeContiguous(base);
}
} // namespace
} // namespace runtime
} // namespace tfcompile
} // namespace tensorflow

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include <cassert>
#include "tensorflow/compiler/aot/runtime.h"
namespace tensorflow {
@ -26,20 +26,29 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
result_index_(static_data.result_index),
args_(new void*[static_data.num_args]),
temps_(new void*[static_data.num_temps]),
arg_index_to_temp_index_(new int32[static_data.num_args]),
num_args_(static_data.num_args),
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
// Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
alloc_args_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.arg_sizes, static_data.num_args, args_,
/*annotate_initialized=*/false);
}
alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.temp_sizes, static_data.num_temps, temps_,
/*annotate_initialized=*/true);
for (int i = 0; i < static_data.num_temps; i++) {
if (static_data.temp_sizes[i] < -1) {
int32 param_number = -(static_data.temp_sizes[i] + 2);
arg_index_to_temp_index_[param_number] = i;
}
}
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
@ -50,11 +59,24 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
}
}
bool XlaCompiledCpuFunction::Run() {
// Propagate pointers to the argument buffers into the temps array. Code
// generated by XLA discovers the incoming argument pointers from the temps
// array.
for (int32 i = 0; i < num_args_; i++) {
temps_[arg_index_to_temp_index_[i]] = args_[i];
}
raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
profile_counters_);
return true;
}
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
cpu_function_runtime::FreeContiguous(alloc_args_);
cpu_function_runtime::FreeContiguous(alloc_temps_);
delete[] args_;
delete[] temps_;
delete[] arg_index_to_temp_index_;
delete[] profile_counters_;
}

View File

@ -60,9 +60,19 @@ class XlaCompiledCpuFunction {
// The raw function to call.
RawFunction raw_function;
// Cardinality and sizes of arg and temp buffers.
// Cardinality and size of arg buffers.
const intptr_t* arg_sizes = nullptr;
size_t num_args = 0;
// Cardinality and size of temp buffers.
//
// If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
//
// If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
// corresponding entry in the temp buffer array needs to be set to null.
//
// If temp_sizes[i] < -1 then the i'th temp is the entry parameter
// -(temp_sizes[i] + 2).
const intptr_t* temp_sizes = nullptr;
size_t num_temps = 0;
@ -113,11 +123,7 @@ class XlaCompiledCpuFunction {
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
bool Run() {
raw_function_(temps_[result_index_], &run_options_,
const_cast<const void**>(args_), temps_, profile_counters_);
return true;
}
bool Run();
// Returns the error message from the previous failed Run call.
//
@ -224,6 +230,17 @@ class XlaCompiledCpuFunction {
void** args_ = nullptr;
void** temps_ = nullptr;
// Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
// XLA generated code to be able to find it.
//
// For now we need to keep around the args_ array because there is code that
// depends on args() returning a void**. However, in the future we may remove
// args_ in favor of using temps_ as the sole storage for the arguments.
int32* arg_index_to_temp_index_;
// The number of incoming arguments.
int32 num_args_;
// Backing memory for individual arg and temp buffers.
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;

View File

@ -58,11 +58,15 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
// Callers don't allocate temporary buffers for parameters. Nor for
// thread-local buffers, which are lowered to alloca.
if (allocation.is_entry_computation_parameter() ||
allocation.is_thread_local()) {
if (allocation.is_constant() || allocation.is_thread_local()) {
// Constants are lowered to globals. Thread locals are lowered to
// allocas.
temp_sizes.push_back(-1);
} else if (allocation.is_entry_computation_parameter()) {
// Entry computation parameters need some preprocessing in
// XlaCompiledCpuFunction::Run. See the comment on
// XlaCompiledCpuFunction::StaticData::temp_sizes.
temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}

View File

@ -56,7 +56,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
auto round = [builder](ThreeFry2x32State v, int rotation) {
auto round = [](ThreeFry2x32State v, int rotation) {
v[0] = v[0] + v[1];
v[1] = RotateLeftS32(v[1], rotation);
v[1] = v[0] ^ v[1];

View File

@ -98,14 +98,13 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
<< "Computation should have progran shape.";
auto program_shape = computation.proto().program_shape();
// For every (unbound) parameter that the computation wants, we manufacture
// some arbitrary data so that we can invoke the computation.
std::vector<std::unique_ptr<GlobalData>> fake_arguments;
for (const Shape& parameter : program_shape.parameters()) {
fake_arguments.push_back(MakeFakeDataOrDie(parameter, client));
}
return fake_arguments;
// Create and run a program which produces a tuple with one element per
// parameter, then return the tuple's constituent buffers.
std::vector<Shape> param_shapes(program_shape.parameters().begin(),
program_shape.parameters().end());
auto fake_input_tuple =
MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client);
return client->DeconstructTuple(*fake_input_tuple).ValueOrDie();
}
} // namespace xla

View File

@ -101,11 +101,14 @@ Status LocalExecutable::ValidateExecutionOptions(
}
}
// Verify that the device the executable was built for is equivalent to the
// device it will run on.
int run_device_ordinal = run_options.device_ordinal() == -1
? backend_->default_device_ordinal()
: run_options.device_ordinal();
// Verify that the device the executable was built for is equivalent
// to the device it will run on.
int run_device_ordinal = run_options.device_ordinal();
if (run_device_ordinal == -1) {
run_device_ordinal = run_options.stream() != nullptr
? run_options.stream()->parent()->device_ordinal()
: backend_->default_device_ordinal();
}
TF_ASSIGN_OR_RETURN(bool devices_equivalent,
backend_->devices_equivalent(
run_device_ordinal, build_options_.device_ordinal()));

View File

@ -1635,6 +1635,32 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
});
}
XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,
GetShape(scatter_indices));
TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates));
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
update_computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferScatterShape(
input_shape, scatter_indices_shape, updates_shape,
to_apply_shape, dimension_numbers));
*instr.mutable_scatter_dimension_numbers() = dimension_numbers;
AddCalledComputation(update_computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kScatter,
{input, scatter_indices, updates});
});
}
XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
@ -1681,7 +1707,7 @@ XlaOp XlaBuilder::Reduce(
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
operand_shape, init_shape, dimensions_to_reduce,
{&operand_shape, &init_shape}, dimensions_to_reduce,
called_program_shape));
for (int64 dim : dimensions_to_reduce) {
@ -2803,6 +2829,13 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
window_bounds);
}
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers) {
return input.builder()->Scatter(input, scatter_indices, updates,
update_computation, dimension_numbers);
}
void Send(const XlaOp& operand, const ChannelHandle& handle) {
return operand.builder()->Send(operand, handle);
}

View File

@ -857,6 +857,11 @@ class XlaBuilder {
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers);
// Enqueues a Send node onto the computation for device-to-device
// communication, to send the given operand to a Recv instruction that shares
// the same channel handle.
@ -1296,6 +1301,10 @@ class XlaBuilder {
friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers);
friend void Send(const XlaOp& operand, const ChannelHandle& handle);
friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
@ -1977,6 +1986,11 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers);
// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel

View File

@ -297,7 +297,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
shape.layout().padded_dimensions_size() == 0) {
return false;
}
CHECK(IsDenseArray(shape));
CHECK(IsDenseArray(shape)) << shape.ShortDebugString();
CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {

View File

@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::strings::Printf;
using tensorflow::strings::StrCat;
namespace xla {

View File

@ -134,8 +134,7 @@ void MetricTableReport::AppendHeader() {
void MetricTableReport::AppendCategoryTable() {
const std::vector<Category> categories = MakeCategories(&entries_);
AppendLine("********** categories table **********");
AppendLine("The left hand side numbers are ", metric_name_, ".");
AppendLine("********** categories table for ", metric_name_, " **********");
AppendLine();
double metric_sum = UnaccountedMetric();
@ -185,8 +184,8 @@ void MetricTableReport::AppendCategoryTable() {
}
void MetricTableReport::AppendEntryTable() {
AppendLine("********** ", entry_name_, " table **********");
AppendLine("The left hand side numbers are ", metric_name_, ".");
AppendLine("********** ", entry_name_, " table for ", metric_name_,
" **********");
AppendLine();
double metric_sum = UnaccountedMetric();

View File

@ -10,6 +10,8 @@ py_library(
srcs = ["types.py"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto_py",
"//tensorflow/python:dtypes",
"//tensorflow/python:platform",
"//third_party/py/numpy",
],
)

View File

@ -23,6 +23,7 @@ import collections
import numpy as _np # Avoids becoming a part of public Tensorflow API.
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.framework import dtypes
# Records corresponsence between a XLA primitive type and Python/Numpy types.
#
@ -40,6 +41,12 @@ TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [
# Maps from XLA primitive types to TypeConversionRecord.
MAP_XLA_TYPE_TO_RECORD = {
xla_data_pb2.BF16:
TypeConversionRecord(
primitive_type=xla_data_pb2.BF16,
numpy_dtype=dtypes.bfloat16.as_numpy_dtype,
literal_field_name='bf16s',
literal_field_type=float),
xla_data_pb2.F16:
TypeConversionRecord(
primitive_type=xla_data_pb2.F16,

View File

@ -2006,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
// Builds a convolution from <options> and runs algebraic simplification on
// the computation. Returns a string description of the result of
// simplification.
auto build_and_simplify = [&options, this]() -> string {
auto build_and_simplify = [&options]() -> string {
HloComputation::Builder b(TestName());
Window window;

View File

@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
ResolveInternal(data));
for (const auto& shaped_buffer : replicated_buffers) {
std::vector<ShapeIndex> shape_indices;
ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
[this, &shape_indices](const Shape& /*subshape*/,
const ShapeIndex& index) {
shape_indices.push_back(index);
});
ShapeUtil::ForEachSubshape(
shaped_buffer->on_device_shape(),
[&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
shape_indices.push_back(index);
});
for (const ShapeIndex& index : shape_indices) {
TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
shaped_buffer->device_ordinal()));

View File

@ -137,9 +137,9 @@ ENTRY entry {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
ASSERT_TRUE(instruction->has_sharding());
TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
EXPECT_EQ(device, 1);
auto device = instruction->sharding_unique_device();
ASSERT_TRUE(device);
EXPECT_EQ(*device, 1);
}
}

View File

@ -877,8 +877,8 @@ Status BufferAssigner::AssignBuffersForComputation(
// important reuse case where an elementwise instruction reuses one of its
// operand's buffer. This improves locality.
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
[this, has_sequential_order, &liveness, &post_order_position,
assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
[has_sequential_order, &liveness, &post_order_position, assignment](
const LogicalBuffer* a, const LogicalBuffer* b) {
// Primary sort is by decreasing buffer size.
const int64 a_size = assignment->buffer_size_(*a);
const int64 b_size = assignment->buffer_size_(*b);
@ -1441,9 +1441,9 @@ void BufferAssigner::BuildColocatedBufferSets(
const HloInstruction* while_hlo = instruction;
ShapeUtil::ForEachSubshape(
while_hlo->shape(),
[this, while_hlo, &points_to_analysis, &buffer_liveness,
buffer_size, computation, colocated_buffer_sets](
const Shape& /*subshape*/, const ShapeIndex& index) {
[this, while_hlo, &points_to_analysis, buffer_size,
colocated_buffer_sets](const Shape& /*subshape*/,
const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add while.init.
AddBufferToColocatedSet(while_hlo->operand(0), index,

View File

@ -156,9 +156,26 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
codegen_passes.run(module);
// Construct ObjectFile from machine code buffer.
return std::unique_ptr<llvm::MemoryBuffer>(
std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
if (VLOG_IS_ON(2)) {
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
llvm::object::ObjectFile::createObjectFile(*memory_buffer);
if (obj_file) {
StatusOr<DisassemblerResult> disasm_result =
disassembler_->DisassembleObjectFile(*obj_file.get());
if (disasm_result.ok()) {
XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text);
} else {
LOG(WARNING) << "Could not disassemble object file!";
}
} else {
LOG(WARNING) << "Could convert memory buffer to object file!";
}
}
return memory_buffer;
}
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {

View File

@ -840,18 +840,29 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
// Callers don't need to allocate temporary buffers for parameters.
if (allocation.is_entry_computation_parameter() ||
allocation.is_constant()) {
buffer_sizes.push_back(-1);
continue;
}
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
}
// Callers don't need to allocate anything for constant buffers. They are
// lowered to globals.
if (allocation.is_constant()) {
buffer_sizes.push_back(-1);
continue;
}
// Callers don't need to allocate anything for entry computation buffers,
// but they do need to stash the pointer to the entry computation buffer
// in the temp buffer table. See the comment on
// XlaCompiledCpuFunction::StaticData::temp_sizes.
if (allocation.is_entry_computation_parameter()) {
buffer_sizes.push_back(-allocation.parameter_number() - 2);
continue;
}
buffer_sizes.push_back(allocation.size());
}

View File

@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable(
// guarded by the mutex.
compute_function_ =
reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
VLOG(1) << "compute_function_ at address "
<< reinterpret_cast<void*>(compute_function_);
}
Status CpuExecutable::AllocateBuffers(
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
CpuExecutable::CreateTempArray(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<OwningDeviceMemory>* buffers) {
CHECK_EQ(buffers->size(), assignment_->Allocations().size());
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<OwningDeviceMemory> owning_buffers(
assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
@ -84,6 +91,8 @@ Status CpuExecutable::AllocateBuffers(
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
allocation.param_shape_index());
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
@ -99,34 +108,34 @@ Status CpuExecutable::AllocateBuffers(
}
int64 buffer_size = allocation.size();
if (!(*buffers)[i].is_null()) {
if (!owning_buffers[i].is_null()) {
VLOG(3) << "buffer #" << i
<< " is in the preallocated result ShapedBuffer";
} else {
TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
device_ordinal, buffer_size));
TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate(
device_ordinal, buffer_size));
unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase();
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
<< (*buffers)[i].opaque() << "]";
<< owning_buffers[i].opaque() << "]";
}
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size);
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
VLOG(3) << "result index: " << result_slice.index();
return Status::OK();
return {{std::move(unowning_buffers), std::move(owning_buffers)}};
}
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
@ -136,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction(
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
// args_array: An array of pointers, each of which points to a parameter.
// The size of this array is determined by the function's arity
// (ProgramShape).
// temps_array: An array of pointers, each of which points to a temporary
// buffer the computation needs. The size of this array is
// determined by buffer analysis.
// args_array: null
// temps_array: An array of pointers, containing pointers to temporary buffers
// required by the executable adn pointers to entry computation
// parameters.
//
std::vector<const void*> args_array;
for (const ShapedBuffer* argument : arguments) {
args_array.push_back(argument->root_buffer().opaque());
}
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@ -169,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction(
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
VLOG(3) << tensorflow::strings::Printf(
" func(void* result, void* params[%zu], void* temps[%zu], "
" func(void* result, void* params[null], void* temps[%zu], "
"uint64 profile_counters[%zu])",
args_array.size(), buffer_pointers.size(), profile_counters_size);
buffer_pointers.size(), profile_counters_size);
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
};
VLOG(3) << tensorflow::strings::Printf(
" params = [%s]",
tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str());
VLOG(3) << " params = nullptr";
VLOG(3) << tensorflow::strings::Printf(
" temps = [%s]",
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
@ -186,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction(
profile_counters);
}
compute_function_(result_buffer, run_options, args_array.data(),
buffer_pointers.data(), profile_counters);
compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(),
profile_counters);
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@ -254,21 +255,18 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
unowning_buffers.reserve(buffers.size());
for (auto& buffer : buffers) {
unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
}
TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(),
arguments, unowning_buffers,
hlo_execution_profile));
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
arguments));
return CreateResultShapedBuffer(run_options, &buffers);
TF_RETURN_IF_ERROR(ExecuteComputeFunction(
&run_options->run_options(), unowning_buffers, hlo_execution_profile));
return CreateResultShapedBuffer(run_options, &owning_buffers);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@ -284,17 +282,15 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
unowning_buffers.reserve(buffers.size());
for (auto& buffer : buffers) {
unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
}
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
CreateResultShapedBuffer(run_options, &buffers));
CreateResultShapedBuffer(run_options, &owning_buffers));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
@ -312,7 +308,6 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
struct AsyncRunTask {
CpuExecutable* executable;
ServiceExecutableRunOptions run_options;
std::vector<const ShapedBuffer*> arguments;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
@ -320,15 +315,14 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
&run_options.run_options(), arguments, unowning_buffers,
&run_options.run_options(), unowning_buffers,
/*hlo_execution_profile=*/nullptr));
}
};
host_stream->EnqueueTask(AsyncRunTask{
this, *run_options,
std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()),
unowning_buffers,
std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))});
host_stream->EnqueueTask(
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
std::make_shared<std::vector<OwningDeviceMemory>>(
std::move(owning_buffers))});
return std::move(result);
}

View File

@ -85,20 +85,29 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
// Allocate buffers required for execution and assign them to the elements of
// "buffers". "buffers" should be sized to the number of buffers in buffer
// assignment. Each vector element corresponds to a particular Index. If
// a vector element already contains a non-null DeviceMemoryBase, then no
// buffer is assigned for this element.
Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
std::vector<OwningDeviceMemory>* buffers);
// Creates an array suitable for passing as the "temps" argument to the JIT
// compiled function pointer.
//
// Returns (unowning_buffers, owning_buffers) where:
//
// - unowning_buffers.data() can be passed as the temps argument as-is and
// includes pointers to the scratch storage required by the computation,
// the live-out buffer into which the result will be written and entry
// computation parameters.
//
// - owning_buffers contains owning pointers to the buffers that were
// allocated by this routine. This routine allocates buffers for temporary
// storage and the live-out buffer into which the computation writes it
// result.
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);

View File

@ -92,9 +92,10 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
const void* shape,
xla::int32 shape_length) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
const void* shape,
xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "AcquireInfeedBufferForDequeue: "
<< ShapeString(shape, shape_length);
@ -111,9 +112,11 @@ void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
return buffer->data();
}
void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
xla::int32 shape_length) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
void* buffer_ptr,
const void* shape_ptr,
xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
<< ShapeString(shape_ptr, shape_length);
@ -125,8 +128,10 @@ void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
std::move(shape));
}
void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
const void* shape_ptr,
xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
<< ShapeString(shape_ptr, shape_length);
@ -143,9 +148,11 @@ void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
return buffer->data();
}
void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
xla::int32 shape_length) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length,
void* buffer_ptr,
const void* shape_ptr,
xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
<< ShapeString(shape_ptr, shape_length);

View File

@ -19,6 +19,8 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
@ -117,9 +119,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
ElementwiseSourceIndex(index, *hlo, i)));
operands.push_back(operand_value);
}
return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
hlo->to_apply(), operands,
llvm_ir::IrName(hlo));
return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
operands, llvm_ir::IrName(hlo));
};
}
return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);

View File

@ -116,6 +116,19 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
computation->root_instruction()->outer_dimension_partitions().size();
}
if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
TF_ASSIGN_OR_RETURN(
computation_root_allocation_,
assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
}
for (const HloInstruction* param : computation->parameter_instructions()) {
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
assignment_.GetUniqueTopLevelSlice(param));
computation_parameter_allocations_[param_slice.allocation()->index()] =
param->parameter_number();
}
InitializeIrFunction(function_name);
// The rdtscp instruction is x86 specific. We will fallback to LLVM's generic
// readcyclecounter if it is unavailable.
@ -132,6 +145,8 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
compute_function_.reset();
computation_root_allocation_ = BufferAllocation::Slice();
computation_parameter_allocations_.clear();
return ir_function;
}
@ -484,23 +499,11 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
llvm::Function* mapped_ir_function =
FindOrDie(emitted_functions_, map->to_apply());
std::vector<llvm::Value*> parameter_addresses;
for (const HloInstruction* operand : map->operands()) {
const llvm_ir::IrArray& array = GetIrArrayFor(operand);
parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_));
}
return EmitElementFunctionCall(mapped_ir_function, map->shape(),
parameter_addresses, "map_function");
}
Status IrEmitter::HandleMap(HloInstruction* map) {
return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
});
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
tensorflow::StringPiece name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
@ -508,9 +511,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
const llvm_ir::IrArray::Index& index) {
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
HloComputation* function = reduce_window->to_apply();
// The called computation should have been emitted previously.
llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// We fold inputs into the accumulator and initialize it to
// the initial value on the reduce_window.
@ -563,11 +563,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
// We are not in the padding, so carry out the computation.
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
llvm::Value* input_value_address =
input_array.EmitArrayElementAddress(input_index, &b_);
llvm::Value* result = EmitElementFunctionCall(
reducer_function, reduce_window->shape(),
{accumulator_address, input_value_address}, "reducer_function");
llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
*reduce_window->to_apply(),
{b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
b_.CreateStore(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
@ -623,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
// The select and scatter computations should have been emitted previously.
llvm::Function* select_function =
FindOrDie(emitted_functions_, select_and_scatter->select());
llvm::Function* scatter_function =
FindOrDie(emitted_functions_, select_and_scatter->scatter());
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
@ -733,11 +726,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
llvm::Value* result = EmitElementFunctionCall(
select_function, output_shape, {selected_value_address, operand_address},
llvm::Value* operand_element = b_.CreateLoad(operand_address);
llvm::Value* result = EmitThreadLocalCall(
*select_and_scatter->select(),
{b_.CreateLoad(selected_value_address), operand_element},
"select_function");
// If the 'select' function returns false, update the selected value and the
@ -764,14 +758,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
llvm::Value* source_value_address =
source_array.EmitArrayElementAddress(source_index, &b_);
llvm::Value* source_value =
source_array.EmitReadArrayElement(source_index, &b_);
llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
llvm::Value* output_value_address =
output_array.EmitArrayElementAddress(selected_index, &b_);
llvm::Value* scatter_value = EmitElementFunctionCall(
scatter_function, source->shape(),
{output_value_address, source_value_address}, "scatter_function");
llvm::Value* output_value =
output_array.EmitReadArrayElement(selected_index, &b_);
llvm::Value* scatter_value =
EmitThreadLocalCall(*select_and_scatter->scatter(),
{output_value, source_value}, "scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@ -1248,46 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
auto param_number = parameter->parameter_number();
auto param_shape = parameter->shape();
// We have to access the parameter at offset param_number in the params
// array. The code generated here is equivalent to this C code:
//
// i8* param_address_untyped = params[param_number];
// Param* param_address_typed = (Param*)param_address_untyped;
//
// Where Param is the actual element type of the underlying buffer (for
// example, float for an XLA F32 element type).
llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset);
param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
if (is_top_level_computation_ &&
hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
// In the entry computation the parameter slots in the %params argument are
// invariant through program execution. In computations that are called
// from the entry computation (via kWhile, kCall and kConditional) the
// parameter slots are *not* invariant since they're written to by their
// callers.
param_address_untyped->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
}
llvm::Value* param_address_typed = b_.CreateBitCast(
param_address_untyped, IrShapeType(param_shape)->getPointerTo());
emitted_value_[parameter] = param_address_typed;
if (!ShapeUtil::IsOpaque(param_shape)) {
AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
}
VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
return Status::OK();
return EmitTargetAddressForOp(parameter);
}
// Returns true if the relative order of the unreduced dimensions stays the same
@ -1751,9 +1706,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
// The called computation should have been emitted previously.
llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@ -1793,10 +1745,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
CHECK(index.end() == it);
// Apply the reduction function to the loaded value.
llvm::Value* input_address =
arg_array.EmitArrayElementAddress(input_index, &b_);
llvm::Value* result = EmitElementFunctionCall(
reducer_function, reduce->shape(), {accumulator_addr, input_address},
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
*reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
"reduce_function");
b_.CreateStore(result, accumulator_addr);
@ -1842,6 +1793,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
return Unimplemented("Send-done is not implemented on CPU.");
}
Status IrEmitter::HandleScatter(HloInstruction*) {
return Unimplemented("Scatter is not implemented on CPUs.");
}
Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
@ -2134,18 +2089,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
std::vector<llvm::Value*> parameter_addresses;
for (const HloInstruction* operand : call->operands()) {
parameter_addresses.push_back(GetEmittedValueFor(operand));
}
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
parameter_addresses, &b_, computation->name(),
{}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
@ -2156,8 +2106,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
call_args, root->shape(), root->outer_dimension_partitions(), &b_,
call_ir_function, computation->name()));
} else {
EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
emitted_value_[call], computation->name());
EmitGlobalCall(*computation, computation->name());
}
return Status::OK();
@ -2238,12 +2187,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
const HloInstruction* init = xla_while->operand(0);
emitted_value_[xla_while] = GetEmittedValueFor(init);
// The called computation should have been emitted previously.
llvm::Function* condition_ir_function =
FindOrDie(emitted_functions_, condition);
llvm::Function* body_ir_function =
FindOrDie(emitted_functions_, xla_while->while_body());
// Generating:
// while (Condition(while_result)) {
// // CopyInsertion pass inserts copies which enable 'while_result' to
@ -2260,12 +2203,10 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
llvm::Value* while_result = GetEmittedValueFor(xla_while);
llvm::Value* while_condition = EmitElementFunctionCall(
condition_ir_function, condition->root_instruction()->shape(),
{while_result}, IrName(xla_while, "cond"));
EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
llvm::Value* while_predicate = b_.CreateICmpNE(
while_condition,
b_.CreateLoad(
GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@ -2280,8 +2221,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
b_.SetInsertPoint(body_bb);
// Calls the body function.
EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
IrName(xla_while, "body"));
EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
// Finishes with a branch back to the header.
b_.CreateBr(header_bb);
@ -2449,8 +2390,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
auto true_arg = conditional->operand(1);
auto false_arg = conditional->operand(2);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
@ -2472,13 +2411,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
llvm::Function* true_function =
FindOrDie(emitted_functions_, true_computation);
llvm::Function* false_function =
FindOrDie(emitted_functions_, false_computation);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
llvm::Value* conditional_result = GetEmittedValueFor(conditional);
// Generating:
// if (pred)
@ -2495,12 +2428,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
SetToFirstInsertPoint(if_data.true_block, &b_);
EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
conditional_result, IrName(conditional, "_true"));
EmitGlobalCall(*conditional->true_computation(),
IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.false_block, &b_);
EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
conditional_result, IrName(conditional, "_false"));
EmitGlobalCall(*conditional->false_computation(),
IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
@ -2701,44 +2634,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
llvm::Value* IrEmitter::EmitTempBufferPointer(
llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
llvm::Type* element_type = IrShapeType(target_shape);
// The alignment and number of bytes within the temporary buffer is determined
// by the maximal shape as determined by buffer assignment.
const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
if (allocation.is_thread_local()) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
if (slice == computation_root_allocation_) {
llvm::Argument* retval = compute_function_->result_arg();
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
retval->addAttrs(attr_builder);
return retval;
}
auto param_it =
computation_parameter_allocations_.find(slice.allocation()->index());
if (param_it != computation_parameter_allocations_.end()) {
int64 param_number = param_it->second;
// We have to access the parameter at offset param_number in the params
// array. The code generated here is equivalent to this C code:
//
// i8* param_address_untyped = params[param_number];
// Param* param_address_typed = (Param*)param_address_untyped;
//
// Where Param is the actual element type of the underlying buffer (for
// example, float for an XLA F32 element type).
llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
llvm::LoadInst* param_address_untyped =
b_.CreateLoad(param_address_offset);
if (!ShapeUtil::IsOpaque(target_shape)) {
AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
AttachDereferenceableMetadataForLoad(param_address_untyped,
target_shape);
}
return param_address_untyped;
}
// Thread-local allocations should only be assigned a single buffer.
const auto& assigned_buffers = allocation.assigned_buffers();
CHECK_EQ(1, assigned_buffers.size());
const Shape& shape = assigned_buffers.begin()->first->shape();
llvm::AllocaInst*& tempbuf_address =
thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}];
if (tempbuf_address == nullptr) {
tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
std::pair<llvm::Function*, BufferAllocation::Slice> key = {
compute_function_->function(), slice};
auto buf_it = thread_local_buffers_.find(key);
if (buf_it == thread_local_buffers_.end()) {
llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
IrShapeType(shape),
tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
MinimumAlignmentForShape(target_shape));
auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
CHECK(it_inserted_pair.second);
buf_it = it_inserted_pair.first;
}
return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
}
if (allocation.is_constant()) {
return FindOrDie(constant_buffer_to_global_, allocation.index());
}
return buf_it->second;
}();
return b_.CreateBitCast(tempbuf_address,
IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
if (is_top_level_computation_ &&
hlo_module_config_.debug_options()
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
// In the entry computation the parameter slots in the %params argument are
// invariant through program execution. In computations that are called
// from the entry computation (via kWhile, kCall and kConditional) the
// parameter slots are *not* invariant since they're written to by their
// callers.
tempbuf_address_base->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
@ -2753,85 +2718,25 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
return b_.CreateBitCast(tempbuf_address_untyped,
element_type->getPointerTo());
IrShapeType(target_shape)->getPointerTo());
}
// Emits a function call returning a single array element. Allocates space
// for a single element_type value, and loads it after call.
llvm::Value* IrEmitter::EmitElementFunctionCall(
llvm::Function* function, const Shape& return_shape,
gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name) {
llvm::Value* return_value_buffer = EmitArrayFunctionCall(
function, return_shape, 1, parameter_addresses, name);
return b_.CreateLoad(
return_value_buffer,
AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
}
// Emits a core function call based on the following pseudo-code.
//
// char** parameter_addresses_buffer =
// allocate buffer with a pointer for each parameter to the function
// for each parameter index, i.e. for i = 0, ..., #parameters:
// parameter_addresses_buffer[i] = parameter_addresses[i]
// call function(return_value_buffer,
// parameter_addresses_buffer,
// temps)
// return return_value_buffer -- address of the return value.
void IrEmitter::EmitArrayFunctionCallInto(
llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
b_.CreateCall(function,
GetArrayFunctionCallArguments(
parameter_addresses, &b_, name,
/*return_value_buffer=*/return_value_buffer,
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::EmitArrayFunctionCall(
llvm::Function* function, const Shape& return_shape, int64 element_count,
gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name) {
llvm::Value* elements =
llvm::ConstantInt::get(b_.getInt64Ty(), element_count);
PrimitiveType return_type = return_shape.element_type();
llvm::Value* return_value_buffer =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
tensorflow::strings::StrCat(name, "_return_value_address"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
name);
return return_value_buffer;
llvm::Value* IrEmitter::EmitTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
if (slice.allocation()->is_thread_local()) {
return EmitThreadLocalTempBufferPointer(slice, target_shape);
} else if (slice.allocation()->is_constant()) {
return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
} else {
return EmitGlobalTempBufferPointer(slice, target_shape);
}
}
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
llvm::Value* addr;
const Shape& target_shape = op->shape();
if (op == op->parent()->root_instruction()) {
// For the root node, we write directly to the output buffer of the
// function.
llvm::Argument* retval = compute_function_->result_arg();
if ((ShapeUtil::IsArray(target_shape) &&
!ShapeUtil::IsZeroElementArray(target_shape)) ||
(ShapeUtil::IsTuple(target_shape) &&
!ShapeUtil::IsEmptyTuple(target_shape))) {
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
retval->addAttrs(attr_builder);
}
addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo());
} else {
// For other nodes, we need the temporary buffer allocated for this node to
// write the result into.
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
addr = EmitTempBufferPointer(slice, target_shape);
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@ -2936,20 +2841,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
PrimitiveType return_type, HloComputation* computation,
const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
std::vector<llvm::Value*> argument_addrs;
for (auto argument : arguments) {
llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
argument->getType(), "arg_addr", &b_);
b_.CreateStore(argument, argument_addr);
argument_addrs.push_back(argument_addr);
llvm::Value* IrEmitter::EmitThreadLocalCall(
const HloComputation& callee,
tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
tensorflow::StringPiece name) {
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
// larger arrays is difficult because we allocate the buffer for this return
// value on the stack.
CHECK(ShapeUtil::IsScalar(return_shape));
PrimitiveType return_type = return_shape.element_type();
std::vector<llvm::Value*> parameter_addrs;
for (llvm::Value* parameter : parameters) {
CHECK(!parameter->getType()->isPointerTy());
llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
parameter->getType(), "arg_addr", &b_);
b_.CreateStore(parameter, parameter_addr);
parameter_addrs.push_back(parameter_addr);
}
return EmitElementFunctionCall(llvm_function,
ShapeUtil::MakeShape(return_type, {}),
argument_addrs, name);
llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(return_type, module_),
tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
b_.CreateCall(
FindOrDie(emitted_functions_, &callee),
GetArrayFunctionCallArguments(
parameter_addrs, &b_, name,
/*return_value_buffer=*/return_value_buffer,
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
/*profile_counters_arg=*/GetProfileCountersArgument()));
return b_.CreateLoad(return_value_buffer);
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
tensorflow::StringPiece name) {
b_.CreateCall(FindOrDie(emitted_functions_, &callee),
GetArrayFunctionCallArguments(
/*parameter_addresses=*/{}, &b_, name,
/*return_value_buffer=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()),
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
const HloComputation& callee) {
const HloInstruction* root_inst = callee.root_instruction();
if (root_inst->opcode() == HloOpcode::kOutfeed) {
return llvm::Constant::getNullValue(b_.getInt8PtrTy());
}
const BufferAllocation::Slice root_buffer =
assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
return EmitTempBufferPointer(root_buffer, root_inst->shape());
}
} // namespace cpu
} // namespace xla

View File

@ -100,14 +100,15 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
// Emits a call to `computation` with scalar arguments `arguments`.
StatusOr<llvm::Value*> EmitScalarCall(
PrimitiveType return_type, HloComputation* computation,
const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
// Emit code to map one element according to `map_instr`.
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
tensorflow::StringPiece name);
protected:
//
// The following methods implement the DfsHloVisitor interface.
@ -143,13 +144,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleMap(HloInstruction* map) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
Status HandleIota(HloInstruction* iota) override;
Status HandleRng(HloInstruction* rng) override;
@ -218,9 +219,18 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// computation function being emitted by this emitter.
llvm::Value* GetTempBuffersArgument();
// Emits code that computes the address of the given temporary buffer to the
// function. target_shape is the shape of this temporary buffer.
// The returned Value's type is a pointer to element_type.
// Helper for EmitTempBufferPointer.
llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
const Shape& target_shape);
// Helper for EmitTempBufferPointer.
llvm::Value* EmitThreadLocalTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape);
// Emits code that computes the address of the given buffer allocation slice.
//
// TODO(sanjoy): This should be renamed to reflect that it no longer provides
// access to just temporaries.
llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
const Shape& target_shape);
@ -232,44 +242,27 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::StringPiece
function_name_suffix); // Used for LLVM IR register names.
// Methods that emit a function call.
// Parameters:
// function - The LLVM function to call.
// return_shape - The return shape of the HLO computation that was used to
// make the function. Not the same as the return type of the function
// in LLVM, since we use output parameters for the return type.
// element_count - number of elements to return (array form only).
// parameter_addresses - pointers to be passed to the function as
// parameters.
// name - used for LLVM IR register names.
// Emits a function call, returning a scalar, often an element of a larger
// array. Returns a Value for the scalar element returned by the function.
llvm::Value* EmitElementFunctionCall(
llvm::Function* function, const Shape& return_shape,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
// Emits a call to a thread local function (e.g. to the computation nested
// within a reduce or a map). Thread local callees (by definition) only write
// to and read from thread local allocations.
//
// `parameters` holds the *scalar values* that need to be passed to the
// callee. The return value is the scalar returned by the callee.
llvm::Value* EmitThreadLocalCall(
const HloComputation& callee,
tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
tensorflow::StringPiece name);
// Array function call emitter. Stores the function's result into a supplied
// buffer.
// Parameters:
// function - The LLVM function to call.
// parameter_addresses - pointers to be passed to the function as
// parameters.
// return_value - pointer to a buffer where the call result is stored.
// Emits a call to a "global" function (e.g. to the computation nested within
// a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to
// the parameters and return values for these computations so there is no need
// to explicitly pass parameters or return results.
void EmitGlobalCall(const HloComputation& callee,
tensorflow::StringPiece name);
void EmitArrayFunctionCallInto(
llvm::Function* function,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
llvm::Value* return_value_buffer, tensorflow::StringPiece name);
// Array function call emitter. Returns a Value for the function's return
// value buffer address. The return value buffer is alloca'ed by this
// function.
llvm::Value* EmitArrayFunctionCall(
llvm::Function* function, const Shape& return_shape, int64 element_count,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name);
// Returns the buffer to which a global call to `callee` would have written
// its result.
llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
@ -408,11 +401,10 @@ class IrEmitter : public DfsHloVisitorWithDefault {
NameUniquer name_uniquer_;
// Map containing all previously emitted computations.
std::map<HloComputation*, llvm::Function*> emitted_functions_;
std::map<const HloComputation*, llvm::Function*> emitted_functions_;
// Map containing all previously emitted thread-local temporary buffers.
std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
llvm::AllocaInst*>
std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
thread_local_buffers_;
// The following fields track the IR emission state. According to LLVM memory
@ -422,6 +414,16 @@ class IrEmitter : public DfsHloVisitorWithDefault {
std::unique_ptr<IrFunction> compute_function_;
llvm::IRBuilder<> b_;
// The buffer allocation slice for the root of the computation being compiled.
// Only relevant for thread local computations.
BufferAllocation::Slice computation_root_allocation_;
// Maps the buffer allocation slices for the parameters to the computation
// being compiled to their parameter numbers. Only relevant for thread local
// computations.
tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
computation_parameter_allocations_;
// Maps HLO instructions to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx_;

View File

@ -80,9 +80,16 @@ void IrFunction::Initialize(const string& function_name,
// void function(i8* retval, i8* run_options, i8** params, i8** temps,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
// retval: points to the returned value.
// params: address of an array with pointers to parameters.
// temps: address of an array with pointers to temporary buffers.
// For thread local functions:
// retval: points to the returned value.
// params: address of an array with pointers to parameters.
// temps: is null
//
// For global functions:
// retval: is null
// params: is null
// temps: address of an array with pointers to temporary buffers and entry
// computation parameters.
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@ -196,18 +203,25 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
llvm::IRBuilder<>* b, tensorflow::StringPiece name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr =
b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
AsStringRef(tensorflow::strings::StrCat(
name, "_parameter_", i, "_address_as_i8ptr")));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
llvm::Value* parameter_addresses_buffer;
if (parameter_addresses.empty()) {
parameter_addresses_buffer =
llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
} else {
parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr =
b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
AsStringRef(tensorflow::strings::StrCat(
name, "_parameter_", i, "_address_as_i8ptr")));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
}
}
const auto to_int8_ptr = [=](llvm::Value* ptr) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -58,13 +59,14 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
// [partition1_dim2_start]
// [partition1_dim2_limit]
//
void __xla_cpu_runtime_ParallelForkJoin(
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
void** temps, uint64* prof_counters, int32 num_partitions,
int64* partitions, int32 num_partitioned_dims, void* function_ptr) {
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
<< " num_partitioned_dims: " << num_partitioned_dims;
CHECK_EQ(params, nullptr);
CHECK_GT(num_partitions, 1);
CHECK_GT(num_partitioned_dims, 0);
const xla::ExecutableRunOptions* run_options =
@ -79,9 +81,9 @@ void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
[i, function, result_ptr, run_options_ptr, params, temps, prof_counters,
[i, function, result_ptr, run_options_ptr, temps, prof_counters,
partitions, offset, &bc]() {
function(result_ptr, run_options_ptr, params, temps,
function(result_ptr, run_options_ptr, nullptr, temps,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@ -77,27 +78,24 @@ void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
} // namespace
void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr,
Eigen::half* out, Eigen::half* lhs,
Eigen::half* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
MatMulImpl<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
transpose_lhs, transpose_rhs);
}
void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
float* lhs, float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
MatMulImpl<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
MatMulImpl<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
#include "tensorflow/core/platform/dynamic_annotations.h"
using tensorflow::int32;
using tensorflow::int64;
@ -74,10 +75,9 @@ void MatMulF64(const void* run_options_ptr, double* out, double* lhs,
} // namespace
void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
float* lhs, float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32(
const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
// BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
@ -88,11 +88,11 @@ void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
// Set thread number back to the previous number.
mkl_set_num_threads_local(prev_num_threads);
}
// BLAS GEMM API for 64-bit Matrix Multiplication
void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
double* lhs, double* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64(
const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
// BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
@ -103,22 +103,26 @@ void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
// Set thread number back to the previous number.
mkl_set_num_threads_local(prev_num_threads);
}
void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
float* out, float* lhs,
float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
float* out, float* lhs, float* rhs,
int64 m, int64 n, int64 k,
int32 transpose_lhs,
int32 transpose_rhs) {
// Set the thread number to 1 for single threaded excution.
int prev_num_threads = mkl_set_num_threads_local(1);
MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
// Set thread number back to the previous number.
mkl_set_num_threads_local(prev_num_threads);
}
void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
double* out, double* lhs,
double* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
double* out, double* lhs,
double* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
// Set the thread number to 1 for single threaded excution.
int prev_num_threads = mkl_set_num_threads_local(1);
MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@ -71,7 +72,8 @@ void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs,
} // namespace
void __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF16(
const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
@ -79,16 +81,22 @@ void __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
transpose_lhs, transpose_rhs);
}
void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr,
float* out, float* lhs,
float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
SingleThreadedMatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k,
transpose_lhs, transpose_rhs);
}
void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
double* out, double* lhs,
double* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
SingleThreadedMatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k,
transpose_lhs, transpose_rhs);
}

View File

@ -220,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// The body adds the reduced value of the Infeed data (first tuple element)
// to the previous accumulator, and returns the accumulator and the continue
// flag (second tuple element) as a tuple.
const auto build_body = [this, &result_shape](const Shape& infeed_shape) {
const auto build_body = [&result_shape](const Shape& infeed_shape) {
XlaComputation body;
XlaBuilder builder("body");
auto prev = Parameter(&builder, 0, result_shape, "prev");

View File

@ -233,6 +233,7 @@ class DfsHloVisitorBase {
virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
virtual Status HandleGather(HloInstructionPtr hlo) = 0;
virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
virtual Status HandlePad(HloInstructionPtr hlo) = 0;

View File

@ -194,6 +194,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
Status HandleScatter(HloInstructionPtr scatter) override {
return DefaultAction(scatter);
}
Status HandleAfterAll(HloInstructionPtr token) override {
return DefaultAction(token);
}

View File

@ -2134,7 +2134,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
};
default:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
HloOpcodeString(hlo->opcode()).c_str());
};

View File

@ -636,7 +636,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@ -749,6 +748,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
],

View File

@ -43,6 +43,8 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (int64 i = 0; i < loop_limit_; ++i) {
profiler->StartHloComputation();

View File

@ -32,16 +32,19 @@ namespace {
// dimensions.
struct MatrixDescriptor {
MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
int64 matrix_num_rows, int64 matrix_num_cols)
int64 matrix_num_rows, int64 matrix_num_cols,
int64 matrix_batch_size)
: data(matrix_data),
transpose(needs_transpose),
num_rows(matrix_num_rows),
num_cols(matrix_num_cols) {}
num_cols(matrix_num_cols),
batch_size(matrix_batch_size) {}
se::DeviceMemoryBase data;
bool transpose; // Whether this matrix needs to be transposed.
int64 num_rows;
int64 num_cols;
int64 batch_size;
};
// Performs a gemm call without an explicit algorithm on lhs_matrix and
@ -51,6 +54,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
DCHECK(!output_matrix.transpose);
const int64 batch_size = lhs_matrix.batch_size;
CHECK_EQ(batch_size, rhs_matrix.batch_size);
CHECK_EQ(batch_size, output_matrix.batch_size);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@ -61,13 +67,30 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
: se::blas::Transpose::kNoTranspose;
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
if (batch_size == 1) {
return stream
->ThenBlasGemm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
&output_data, /*leading dim of output=*/output_matrix.num_rows)
.ok();
}
int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
return stream
->ThenBlasGemm(
->ThenBlasGemmStridedBatched(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
&output_data, /*leading dim of output=*/output_matrix.num_rows)
output_matrix.num_cols, /*size of reduce dim=*/k,
/*alpha=*/alpha, lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
/*beta=*/0.0, &output_data,
/*leading dim of output=*/output_matrix.num_rows, output_stride,
batch_size)
.ok();
}
@ -94,6 +117,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
se::blas::ProfileResult* output_profile_result) {
DCHECK(!output_matrix.transpose);
CHECK_EQ(1, lhs_matrix.batch_size);
CHECK_EQ(1, rhs_matrix.batch_size);
CHECK_EQ(1, output_matrix.batch_size);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@ -174,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
return &DoGemm<float>;
case F64:
return &DoGemm<double>;
case C64:
return &DoGemm<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -187,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type)
return &DoGemmWithAlgorithm<float>;
case F64:
return &DoGemmWithAlgorithm<double>;
case C64:
return &DoGemmWithAlgorithm<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -199,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
return &DoGemmAutotune<float>;
case F64:
return &DoGemmAutotune<double>;
case C64:
return &DoGemmAutotune<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -217,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
return se::blas::ComputationType::kF32;
case F64:
return se::blas::ComputationType::kF64;
case C64:
return se::blas::ComputationType::kComplexF32;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -270,12 +305,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::DeviceMemoryBase output_data =
buffer_allocations.GetDeviceAddress(output_buffer_);
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
ShapeUtil::Rank(output_shape_));
int64 row_dim = dim_nums.lhs_batch_dimensions_size();
int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
output_shape_.dimensions().end() - 2, 1,
std::multiplies<int64>());
// Check that the batch dims don't cover the last two dims.
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
CHECK_NE(row_dim, batch_dim);
CHECK_NE(col_dim, batch_dim);
}
// Verify that the non-batch dimensions are minor-most. This is required for
// efficient access.
for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
}
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
// as column when mapping a matrix Dot to BLAS gemm.
int64 output_num_rows = output_shape_.dimensions(0);
int64 output_num_cols = output_shape_.dimensions(1);
int64 output_num_rows = output_shape_.dimensions(row_dim);
int64 output_num_cols = output_shape_.dimensions(col_dim);
// BLAS gemm expects the inputs and the output are in column-major order.
// Therefore, we need to convert dot between row-major matrices to that
@ -298,31 +358,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
// the leading dimension of the LHS matrix of gemm is the number of rows in
// B^T and thus the number of columns in B.
auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape,
bool transpose) -> MatrixDescriptor {
bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0;
bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) !=
LayoutUtil::Minor(output_shape_.layout(), 0);
return MatrixDescriptor(data, transpose ^ layout_mismatch,
shape.dimensions(is_row_major),
shape.dimensions(!is_row_major));
auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
bool transpose) -> MatrixDescriptor {
bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
LayoutUtil::Minor(output_shape_.layout(), row_dim);
return MatrixDescriptor(
data, transpose ^ layout_mismatch,
shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
batch_size);
};
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
const MatrixDescriptor lhs_descriptor = make_descriptor(
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
const MatrixDescriptor rhs_descriptor = make_descriptor(
rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1);
rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
// autotune this gemm to figure out the best algorithm.
auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, se::Stream* stream) {
auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, se::Stream* stream) {
PrimitiveType element_type = output_shape_.element_type();
se::blas::ComputationType computation_type =
GetBlasComputationType(element_type);
// TODO(b/112111608): Implement auto tune for batched gemm.
if (batch_size != 1) {
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
alpha_, stream);
}
auto thunk_name = [&] {
return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
: "<null>";
@ -368,16 +434,16 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
bool launch_ok;
if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) {
launch_ok = launch(
lhs_descriptor, rhs_descriptor,
MatrixDescriptor(output_data, false, output_num_rows, output_num_cols),
stream);
if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
launch_ok = launch(lhs_descriptor, rhs_descriptor,
MatrixDescriptor(output_data, false, output_num_rows,
output_num_cols, batch_size),
stream);
} else {
launch_ok = launch(
rhs_descriptor, lhs_descriptor,
MatrixDescriptor(output_data, false, output_num_cols, output_num_rows),
stream);
launch_ok = launch(rhs_descriptor, lhs_descriptor,
MatrixDescriptor(output_data, false, output_num_cols,
output_num_rows, batch_size),
stream);
}
if (!launch_ok) {

View File

@ -293,7 +293,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
// the respective location in ShapedBuffer.
std::set<se::DeviceMemoryBase> buffers_in_result;
TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus(
[&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
[&buffer_allocations, &buffers_in_result, this](
const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
const auto& sources = this->GetRootPointsToSet().element(index);
// The points-to set is unambiguous so the set should be a

View File

@ -176,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints(
TF_RETURN_IF_ERROR(
AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
}
// For batched dot we require the default layout.
// TODO(b/112111608): This is overly conservative, the only real restriction
// is that batch dimensions must be major.
if (instruction->opcode() == HloOpcode::kDot &&
ImplementedAsGemm(*instruction) &&
instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
// Verify that the batch dims come before the row and col dims.
const DotDimensionNumbers& dim_nums =
instruction->dot_dimension_numbers();
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
ShapeUtil::Rank(instruction->shape()));
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2);
}
// Set both inputs and the output to default layout.
Shape op0_shape = instruction->operand(0)->shape();
LayoutUtil::SetToDefaultLayout(&op0_shape);
Shape op1_shape = instruction->operand(1)->shape();
LayoutUtil::SetToDefaultLayout(&op1_shape);
Shape output_shape = instruction->shape();
LayoutUtil::SetToDefaultLayout(&output_shape);
TF_RETURN_IF_ERROR(
constraints->SetOperandLayout(op0_shape, instruction, 0));
TF_RETURN_IF_ERROR(
constraints->SetOperandLayout(op1_shape, instruction, 1));
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(output_shape, instruction));
}
}
return Status::OK();
}

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -31,6 +33,8 @@ namespace xla {
namespace gpu {
namespace {
namespace op = xla::testing::opcode_matchers;
using LayoutAssignmentTest = HloTestBase;
TEST_F(LayoutAssignmentTest, Elementwise) {
@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
}
}
TEST_F(LayoutAssignmentTest, DotLayout) {
const char* hlo_text = R"(
HloModule DotLayout
ENTRY dot {
p0 = f32[8,8,256,64]{3,1,2,0} parameter(0)
p1 = f32[8,8,256,64]{3,1,2,0} parameter(1)
ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1),
lhs_batch_dims={0,1}, lhs_contracting_dims={3},
rhs_batch_dims={0,1}, rhs_contracting_dims={3}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_text));
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
GpuLayoutAssignment layout_assignment(&computation_layout,
backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape =
ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Dot(op::ShapeWithLayout(expected_shape),
op::ShapeWithLayout(expected_shape)));
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -38,24 +38,27 @@ namespace gpu {
namespace {
// Return whether the given shape is a matrix with no padding.
bool IsRank2WithNoPadding(const Shape& shape) {
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) {
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 &&
!LayoutUtil::IsPadded(shape);
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape) {
const Shape& output_shape,
int64 batch_dimensions_size) {
// The inputs and the output must
// 1) be matrices with no padding and a non-zero number of elements,
// 2) have an allowed element type.
PrimitiveType output_primitive_type = output_shape.element_type();
bool type_is_allowed =
(output_primitive_type == F16 || output_primitive_type == F32 ||
output_primitive_type == F64);
return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape) &&
output_primitive_type == F64 || output_primitive_type == C64);
return type_is_allowed &&
IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
IsRank2WithNoPadding(output_shape, batch_dimensions_size) &&
!ShapeUtil::IsZeroElementArray(lhs_shape) &&
!ShapeUtil::IsZeroElementArray(rhs_shape);
}
@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) {
CHECK_EQ(dot.opcode(), HloOpcode::kDot);
const Shape& lhs_shape = dot.operand(0)->shape();
const Shape& rhs_shape = dot.operand(1)->shape();
const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
dim_numbers.lhs_batch_dimensions_size())) {
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
return true;

View File

@ -125,6 +125,10 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) {
return Unimplemented("Recv-done is not implemented on GPU");
}
Status IrEmitter::HandleScatter(HloInstruction*) {
return Unimplemented("Scatter is not implemented on GPUs.");
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
std::vector<llvm::Value*> base_ptrs;
for (const HloInstruction* operand : tuple->operands()) {
@ -450,6 +454,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const Shape& lhs_shape = lhs_instruction->shape();
const Shape& rhs_shape = rhs_instruction->shape();
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
dnums.rhs_batch_dimensions_size());
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
llvm::Type* index_type = b_.getInt64Ty();
@ -485,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const int64 lhs_reduction_dimension =
ShapeUtil::GetDimensionNumber(lhs_shape, -1);
const int64 rhs_reduction_dimension =
ShapeUtil::Rank(rhs_shape) >= 2
ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size()
? ShapeUtil::GetDimensionNumber(rhs_shape, -2)
: 0;
: dnums.lhs_batch_dimensions_size();
// Check that the batch dims don't cover the last two dims.
for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
CHECK_NE(lhs_reduction_dimension, batch_dim);
CHECK_NE(rhs_reduction_dimension, batch_dim);
}
// Verify the reduction dimension in the two operands are the same size.
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
@ -502,6 +515,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
// We don't have to iterate over the batch dimensions in both arrays, simplify
// the loop nest of the rhs.
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
rhs_index[i] = lhs_index[i];
}
// Create the reduction loop which does the sum of products reduction.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
/*start_index=*/0,
@ -564,7 +584,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
target_index.push_back(lhs_index[dimension]);
}
}
for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) {
// Skip over the batch dimensions to not have them in the index twice.
for (size_t dimension = dnums.lhs_batch_dimensions_size();
dimension < rhs_index.size(); ++dimension) {
if (dimension != rhs_reduction_dimension) {
target_index.push_back(rhs_index[dimension]);
}

View File

@ -86,6 +86,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleFusion(HloInstruction* fusion) override;

View File

@ -171,40 +171,6 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
return DfsHloVisitor::Postprocess(hlo);
}
namespace {
bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
const HloInstruction& hlo) {
// `hlo` needs to satisfy the following conditions to be implemented as a
// host-to-device cuMemcpy.
//
// 1. `hlo` is a kCopy instruction.
// 2. `hlo`'s only operand is a kConstant instruction.
// 3. `hlo` and its operand have the same shape (thus the same layout too).
// 4. The address of `hlo`'s buffer is known at runtime (without dereferencing
// pointers in a tuple).
return hlo.opcode() == HloOpcode::kCopy &&
hlo.operand(0)->opcode() == HloOpcode::kConstant &&
ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok();
}
bool ImplementedAsDeviceToDeviceMemcpy(
const BufferAssignment& buffer_assignment, const HloInstruction& hlo) {
// `hlo` needs to satisfy three conditions to be implemented as a
// device-to-device cuMemcpy.
//
// 1. `hlo` is a kCopy instruction.
// 2. `hlo` and its operand have the same shape (thus the same layout too).
// 3. `hlo` and its operand have a statically-known buffer assignment
// (constants do not, for instance), which means the source buffer also
// resides on the device.
return hlo.opcode() == HloOpcode::kCopy &&
ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() &&
buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok();
}
} // namespace
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
const HloInstruction& inst,
tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
@ -379,11 +345,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
if (dnums.lhs_batch_dimensions_size() > 0 ||
dnums.rhs_batch_dimensions_size() > 0) {
return Unimplemented("Dot with batch dimensions not implemented.");
}
if (ImplementedAsGemm(*dot)) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
@ -730,13 +691,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
*copy)) {
thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy));
return Status::OK();
}
if (ImplementedAsDeviceToDeviceMemcpy(
ir_emitter_context_->buffer_assignment(), *copy)) {
CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
const BufferAssignment& buffer_assignment =
ir_emitter_context_->buffer_assignment();
if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
copy->shape().layout()) &&
buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
return Status::OK();
}

View File

@ -114,21 +114,20 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
// Gets the GPU name as it's known to LLVM for a given compute capability. If
// we see an unrecognized compute capability, we return "sm_30".
static string GetSmName(std::pair<int, int> compute_capability) {
static auto* m = new std::map<std::pair<int, int>, int>(
{{{2, 0}, 20},
{{2, 1}, 21},
{{3, 0}, 30},
{{3, 2}, 32},
{{3, 5}, 35},
{{3, 7}, 37},
{{5, 0}, 50},
{{5, 2}, 52},
{{5, 3}, 53},
{{6, 0}, 60},
{{6, 1}, 61},
{{6, 2}, 62},
// TODO: Change this to 70 once LLVM NVPTX supports it
{{7, 0}, 60}});
static auto* m = new std::map<std::pair<int, int>, int>({
{{3, 0}, 30},
{{3, 2}, 32},
{{3, 5}, 35},
{{3, 7}, 37},
{{5, 0}, 50},
{{5, 2}, 52},
{{5, 3}, 53},
{{6, 0}, 60},
{{6, 1}, 61},
{{6, 2}, 62},
{{7, 0}, 70},
{{7, 2}, 72},
});
int sm_version = 30;
auto it = m->find(compute_capability);
if (it != m->end()) {
@ -329,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
if (linker.linkInModule(
std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded,
[](Module& M, const StringSet<>& GVS) {
internalizeModule(M, [&M, &GVS](const GlobalValue& GV) {
internalizeModule(M, [&GVS](const GlobalValue& GV) {
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
});
})) {

View File

@ -115,15 +115,23 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
// will be broadcasted and have not been observed to cause data locality issues.
// TODO(b/111977086): Improve reduce emitters to remove this limitation.
bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
std::vector<HloInstruction*> params;
if (instr->opcode() == HloOpcode::kFusion) {
params = instr->fused_parameters();
} else {
for (HloInstruction* operand : instr->operands()) {
params.push_back(operand);
}
}
int64 max_rank = 0;
const Layout* max_rank_layout;
for (HloInstruction* param : instr->fused_parameters()) {
for (HloInstruction* param : params) {
if (ShapeUtil::Rank(param->shape()) > max_rank) {
max_rank = ShapeUtil::Rank(param->shape());
max_rank_layout = &param->shape().layout();
}
}
return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) {
return c_all_of(params, [&](HloInstruction* param) {
return (ShapeUtil::Rank(param->shape()) < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
});
@ -221,7 +229,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
const bool is_loop_fusion =
producer->opcode() == HloOpcode::kFusion &&
producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
if (!is_loop_fusion) {
if (!producer->IsElementwise() && !is_loop_fusion) {
VLOG(3) << producer->name() << " is not a loop fusion.";
continue;
}

View File

@ -256,6 +256,26 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
exp = f32[2,2,2]{2,1,0} exponential(p0)
reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
})"))
.ValueOrDie();
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
SCOPED_TRACE(module->ToString());
const HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
const HloInstruction* fusion = root->operand(0)->operand(0);
ASSERT_TRUE(fusion->IsMultiOutputFusion());
EXPECT_THAT(fusion->fused_expression_root(),
op::Tuple(op::Reduce(), op::Exp()));
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_add {

View File

@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@ -148,7 +147,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// support BF16 operations without directly implementing a BF16 lowering for
// most ops.
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
pipeline.AddPass<DotDecomposer>();
{
auto& pass =

View File

@ -160,6 +160,8 @@ message HloInstructionProto {
// present for Send and Recv instructions and their SendDone and RecvDone
// partners.
bool is_host_transfer = 47;
xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
}
// Serialization of HloComputation.

View File

@ -49,9 +49,9 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
// The default number of bytes accessed for an instruction is the sum of the
// sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
// handle opaque types.
float bytes_accessed = shape_size_(hlo->shape());
float bytes_accessed = GetShapeSize(hlo->shape());
for (const HloInstruction* operand : hlo->operands()) {
bytes_accessed += shape_size_(operand->shape());
bytes_accessed += GetShapeSize(operand->shape());
}
current_properties_[kBytesAccessedKey] = bytes_accessed;
@ -121,6 +121,13 @@ Status HloCostAnalysis::HandleElementwiseOp(
}
}
int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
if (!LayoutUtil::HasLayout(shape)) {
return 0;
}
return shape_size_(shape);
}
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}
@ -181,21 +188,21 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
}
Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2;
current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicSlice(
const HloInstruction* dynamic_slice) {
current_properties_[kBytesAccessedKey] =
shape_size_(dynamic_slice->shape()) * 2;
GetShapeSize(dynamic_slice->shape()) * 2;
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicUpdateSlice(
const HloInstruction* dynamic_update_slice) {
current_properties_[kBytesAccessedKey] =
shape_size_(dynamic_update_slice->operand(1)->shape()) * 2;
GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2;
return Status::OK();
}
@ -204,7 +211,7 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
// through them). The memory touched is then only the size of the output
// index table of the tuple.
current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
return Status::OK();
}
@ -526,12 +533,12 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
// TODO(b/33004697): Compute correct cost here, taking the actual number of
// replicas into account.
double flops = 0.0;
ShapeUtil::ForEachSubshape(
crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
if (ShapeUtil::IsArray(subshape)) {
flops += ShapeUtil::ElementsIn(subshape);
}
});
ShapeUtil::ForEachSubshape(crs->shape(),
[&](const Shape& subshape, const ShapeIndex&) {
if (ShapeUtil::IsArray(subshape)) {
flops += ShapeUtil::ElementsIn(subshape);
}
});
current_properties_[kFlopsKey] = flops;
return Status::OK();
}
@ -546,15 +553,9 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
}
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
// Compute the properties of the fused expression and attribute them to the
// fusion node. Use a dummy shape_size to avoid any errors from trying to
// calculate the size of a shape that does not have a layout, since nodes
// inside fusion nodes do not necessarily have a layout assigned.
ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
TF_ASSIGN_OR_RETURN(
current_properties_,
ProcessSubcomputation(fusion->fused_instructions_computation(),
&shape_size));
ProcessSubcomputation(fusion->fused_instructions_computation()));
// Fusion nodes that produce a tuple also produce the entries in the tuple.
// Ignore the memory accessed inside fused ops, since fusion is supposed to
@ -563,11 +564,11 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
ShapeUtil::ForEachSubshape(
fusion->shape(),
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
current_properties_[kBytesAccessedKey] += shape_size_(subshape);
current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
});
for (const HloInstruction* operand : fusion->operands()) {
current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape());
}
return Status::OK();
@ -648,6 +649,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
return Status::OK();
}
Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
// TODO(b/32945756): Compute the properties of the sub-computation.
return Status::OK();
}
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
return Status::OK();
}
@ -685,11 +691,8 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
}
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
HloComputation* computation, const ShapeSizeFunction* shape_size) {
if (shape_size == nullptr) {
shape_size = &shape_size_;
}
HloCostAnalysis visitor(*shape_size, per_second_rates_);
HloComputation* computation) {
HloCostAnalysis visitor(shape_size_, per_second_rates_);
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.properties();
}

View File

@ -104,6 +104,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
Status HandleGather(const HloInstruction* gather) override;
Status HandleScatter(const HloInstruction* scatter) override;
Status FinishVisit(const HloInstruction* root) override;
Status Preprocess(const HloInstruction* hlo) override;
@ -149,11 +150,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
const Properties& per_second_rates);
// Returns the properties computed from visiting the computation rooted at the
// given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null,
// otherwise uses shape_size_.
StatusOr<Properties> ProcessSubcomputation(
HloComputation* computation,
const ShapeSizeFunction* shape_size = nullptr);
// given hlo.
StatusOr<Properties> ProcessSubcomputation(HloComputation* computation);
// Utility function to handle all element-wise operations.
Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
@ -170,6 +168,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
const HloToProperties& hlo_to_properties);
// Decorates shape_size_ by returning 0 immediately if the shape does not have
// a layout.
int64 GetShapeSize(const Shape& shape) const;
// Function which computes the size of the top-level of a given shape (not
// including nested elements, if any). If null then bytes_accessed methods
// return an error.

View File

@ -2365,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
auto make_cond = [this, &data_shape]() {
auto make_cond = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
@ -2374,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
auto make_body = [this, &data_shape]() {
auto make_body = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));

View File

@ -1481,8 +1481,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
ShapeUtil::Rank(arg->shape()) - dimensions.size());
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
/*arg=*/arg->shape(),
/*init_value=*/init_value->shape(),
{&arg->shape(), &init_value->shape()},
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))

View File

@ -1019,6 +1019,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kWhite;
}
return kGreen;
case HloOpcode::kScatter:
// Do not de-emphasize Scatter, since it involves significant work.
case HloOpcode::kCopy:
// Emphasize copy nodes, which are either physical transposes (and thus
// significant), or copies of read-only buffers (and thus dead weight).

View File

@ -404,6 +404,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
*gather_dimension_numbers, gather_window_bounds);
break;
}
case HloOpcode::kScatter: {
TF_RET_CHECK(proto.operand_ids_size() == 3)
<< "Scatter instruction should have 3 operands but sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_scatter_dimension_numbers())
<< "Scatter instruction should have ScatterDimensionNumbers set.";
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Scatter instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>(
proto.scatter_dimension_numbers());
instruction =
CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers);
break;
}
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@ -812,11 +828,25 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
return MakeUnique<HloReduceInstruction>(
shape, arg, init_value, dimensions_to_reduce, reduce_computation);
auto instruction = WrapUnique(new HloReduceInstruction(
shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
return std::move(instruction);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
std::vector<HloInstruction*> all_args;
all_args.reserve(operands.size() * 2);
all_args.insert(all_args.end(), operands.begin(), operands.end());
all_args.insert(all_args.end(), init_values.begin(), init_values.end());
return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce,
reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
@ -1062,6 +1092,16 @@ bool HloInstruction::HasSideEffect() const {
gather_dim_numbers, window_bounds);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers) {
return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices,
updates, update_computation,
scatter_dim_numbers);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
@ -1124,6 +1164,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
@ -1587,6 +1628,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@ -1693,6 +1735,7 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kScatter:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@ -1711,6 +1754,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kScatter:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;
@ -1977,7 +2021,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
opcode() == HloOpcode::kReduce ||
opcode() == HloOpcode::kCrossReplicaSum) {
opcode() == HloOpcode::kCrossReplicaSum ||
opcode() == HloOpcode::kScatter) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
@ -2013,6 +2058,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kScatter:
extra.push_back(
StrCat("to_apply=\n", to_apply()->ToString(new_options)));
break;
@ -2311,6 +2357,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);
case HloOpcode::kGather:
return visitor->HandleGather(this);
case HloOpcode::kScatter:
return visitor->HandleScatter(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
@ -3171,4 +3219,9 @@ tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
return Cast<HloGatherInstruction>(this)->gather_window_bounds();
}
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
const {
return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}
} // namespace xla

View File

@ -541,17 +541,34 @@ class HloInstruction {
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
// is applied successively to every element in operand. That is, if f is the
// function to apply (which either takes 2 [accumulator, value] or 3
// [accumulator, index, value] arguments) and init is a reduction operator
// specified initial value (for example, 0 for addition), then this operation
// will compute:
// f(f(init, [index0], value0), [index1], value1), ...)
// is applied successively to every element in operand. For example, let f be
// the function to apply, which takes 2 arguments, an accumulator and the
// current value. Let init be an initial value (which is normally chosen to be
// the identity element for f, e.g. 0 if f is addition).
// Then the reduce HLO will compute:
// f(f(init, value0), value1), ...)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// A more general, multiple-argument version of the above.
// The function to apply, f, now takes N arguments:
// [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
// init_valueN], and returns an N-tuple. The performed computation is (for
// commutative and associative f operators) equivalent to:
//
// f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
// f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
// ..., inputN.value1)
// ...
// TODO(b/112040122): Add support to this in HLO passes and in backends.
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Creates a reduce-window instruction, where the computation (given
// by the handle) is applied window-wise at each valid window
// position in the operand.
@ -644,6 +661,12 @@ class HloInstruction {
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers);
// Creates a kDomain instruction which delimits an HLO domain which have
// the provided user and operand side metadata.
static std::unique_ptr<HloInstruction> CreateDomain(
@ -1014,9 +1037,7 @@ class HloInstruction {
if (sharding_ == nullptr) {
return tensorflow::gtl::optional<int64>();
}
auto device = sharding_->UniqueDevice();
return device.ok() ? device.ValueOrDie()
: tensorflow::gtl::optional<int64>();
return sharding_->UniqueDevice();
}
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
@ -1454,6 +1475,9 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_window_bounds.
tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
// Old methods kept for smooth subclassing transition END.
protected:

View File

@ -1425,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
"index_vector_dim=2, window_bounds={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyScatter) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
Shape scatter_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
Shape scatter_updates_shape =
ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
HloComputation::Builder builder("Scatter");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
HloInstruction* scatter_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
1, scatter_indices_tensor_shape, "scatter_indices"));
HloInstruction* scatter_updates =
builder.AddInstruction(HloInstruction::CreateParameter(
2, scatter_updates_shape, "scatter_updates"));
HloComputation::Builder update_builder("Scatter.update");
update_builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1"));
update_builder.AddInstruction(
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2"));
auto module = CreateNewModule();
auto* update_computation =
module->AddEmbeddedComputation(update_builder.Build());
HloInstruction* scatter_instruction =
builder.AddInstruction(HloInstruction::CreateScatter(
input_tensor_shape, input, scatter_indices, scatter_updates,
update_computation,
HloScatterInstruction::MakeScatterDimNumbers(
/*update_window_dims=*/{4, 5, 6, 7, 8},
/*inserted_window_dims=*/{},
/*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/2)));
module->AddEntryComputation(builder.Build());
EXPECT_EQ(
scatter_instruction->ToString(),
"%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
"scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
"s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
"f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
"update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
"scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
"to_apply=%Scatter.update");
}
TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
// Tests stringification of a simple op, fusion, while, and conditional.
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});

View File

@ -438,13 +438,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl(
}
HloReduceInstruction::HloReduceInstruction(
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduce, shape),
dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
AppendOperand(arg);
AppendOperand(init_value);
for (HloInstruction* arg : args) {
AppendOperand(arg);
}
AppendComputation(reduce_computation);
}
@ -477,8 +478,8 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return MakeUnique<HloReduceInstruction>(
shape, new_operands[0], new_operands[1], dimensions(), to_apply());
return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(),
to_apply());
}
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
@ -2015,4 +2016,91 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
gather_window_bounds());
}
HloScatterInstruction::HloScatterInstruction(
const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers)
: HloInstruction(HloOpcode::kScatter, shape) {
AppendOperand(operand);
AppendOperand(scatter_indices);
AppendOperand(updates);
AppendComputation(update_computation);
scatter_dimension_numbers_ =
MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
}
string HloScatterInstruction::ScatterDimensionNumbersToString() const {
string update_window_dims =
StrCat("update_window_dims={",
Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
string inserted_window_dims = StrCat(
"inserted_window_dims={",
Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
string scatter_dims_to_operand_dims = StrCat(
"scatter_dims_to_operand_dims={",
Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
"}");
string index_vector_dim = StrCat(
"index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
return Join<std::initializer_list<string>>(
{update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
index_vector_dim},
", ");
}
/* static */ ScatterDimensionNumbers
HloScatterInstruction::MakeScatterDimNumbers(
tensorflow::gtl::ArraySlice<int64> update_window_dims,
tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
int64 index_vector_dim) {
ScatterDimensionNumbers scatter_dim_numbers;
for (int64 update_window_dim : update_window_dims) {
scatter_dim_numbers.add_update_window_dims(update_window_dim);
}
for (int64 inserted_window_dim : inserted_window_dims) {
scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
}
for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
scatter_dim_numbers.add_scatter_dims_to_operand_dims(
scatter_dim_to_operand_dim);
}
scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
return scatter_dim_numbers;
}
HloInstructionProto HloScatterInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
return proto;
}
std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {ScatterDimensionNumbersToString()};
}
bool HloScatterInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
return protobuf_util::ProtobufEquals(
scatter_dimension_numbers(),
casted_other.scatter_dimension_numbers()) &&
eq_computations(to_apply(), casted_other.to_apply());
}
std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return MakeUnique<HloScatterInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
scatter_dimension_numbers());
}
} // namespace xla

View File

@ -331,7 +331,7 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
explicit HloReduceInstruction(
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
@ -1198,6 +1198,45 @@ class HloGatherInstruction : public HloInstruction {
std::vector<int64> gather_window_bounds_;
};
class HloScatterInstruction : public HloInstruction {
public:
explicit HloScatterInstruction(
const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers);
const ScatterDimensionNumbers& scatter_dimension_numbers() const {
CHECK(scatter_dimension_numbers_ != nullptr);
return *scatter_dimension_numbers_;
}
// Returns the dump string of the scatter dimension numbers.
string ScatterDimensionNumbersToString() const;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
// Creates an instance of ScatterDimensionNumbers.
static ScatterDimensionNumbers MakeScatterDimNumbers(
tensorflow::gtl::ArraySlice<int64> update_window_dims,
tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
int64 index_vector_dim);
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_

View File

@ -299,9 +299,12 @@ TokKind HloLexer::LexNumberOrPattern() {
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strto64(
StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
return TokKind::kInt;
auto slice = StringPieceFromPointers(token_start_, current_ptr_);
if (tensorflow::strings::safe_strto64(slice, &int64_val_)) {
return TokKind::kInt;
}
LOG(ERROR) << "Failed to parse int literal: " << slice;
return TokKind::kError;
}
static LazyRE2 neg_inf = {"-inf"};

View File

@ -118,6 +118,7 @@ namespace xla {
V(kReverse, "reverse") \
V(kRng, "rng") \
V(kRoundNearestAfz, "round-nearest-afz") \
V(kScatter, "scatter") \
V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \

View File

@ -865,18 +865,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReduce: {
auto loc = lexer_.GetLoc();
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
optional<std::vector<tensorflow::int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.size() % 2) {
return Error(loc, StrCat("expects an even number of operands, but has ",
operands.size(), " operands"));
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operand=*/operands[0], /*init_value=*/operands[1],
shape, /*operands=*/
tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0,
operands.size() / 2),
/*init_values=*/
tensorflow::gtl::ArraySlice<HloInstruction*>(
operands, operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
}
@ -1242,6 +1252,42 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
dim_numbers, *window_bounds));
break;
}
case HloOpcode::kScatter: {
optional<std::vector<tensorflow::int64>> update_window_dims;
attrs["update_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
optional<std::vector<tensorflow::int64>> inserted_window_dims;
attrs["inserted_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
optional<std::vector<tensorflow::int64>> scatter_dims_to_operand_dims;
attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
AttrTy::kBracedInt64List,
&scatter_dims_to_operand_dims};
optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<HloComputation*> update_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&update_computation};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
ScatterDimensionNumbers dim_numbers =
HloScatterInstruction::MakeScatterDimNumbers(
/*update_window_dims=*/*update_window_dims,
/*inserted_window_dims=*/*inserted_window_dims,
/*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateScatter(
shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
/*updates=*/operands[2], *update_computation, dim_numbers));
break;
}
case HloOpcode::kDomain: {
DomainData domain;
attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
@ -1590,6 +1636,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value,
"value ", value, " is out of range for literal's primitive type ",
PrimitiveType_Name(literal->shape().element_type())));
}
} else if (std::is_unsigned<LiteralNativeT>::value) {
CHECK((std::is_same<ParsedElemT, tensorflow::int64>::value ||
std::is_same<ParsedElemT, bool>::value))
<< "Unimplemented checking for ParsedElemT";
ParsedElemT upper_bound;
if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) {
upper_bound = std::numeric_limits<ParsedElemT>::max();
} else {
upper_bound =
static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max());
}
if (value > upper_bound || value < 0) {
// Value is out of range for LiteralNativeT.
return TokenError(StrCat(
"value ", value, " is out of range for literal's primitive type ",
PrimitiveType_Name(literal->shape().element_type())));
}
} else if (value > static_cast<ParsedElemT>(
std::numeric_limits<LiteralNativeT>::max()) ||
value < static_cast<ParsedElemT>(

View File

@ -758,6 +758,46 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5
ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
}
)"
},
{
"scatter",
R"(HloModule StringifyScatter
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
%updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
}
)"
},
{
"ConstantUnsignedNoUnderflow",
R"(HloModule ConstantUnsignedNoUnderflow_module
ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
ROOT %constant = u64[] constant(1)
}
)"
},
{
"ConstantUnsignedNoOverflow",
R"(HloModule ConstantUnsignedNoOverflow_module
ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
ROOT %constant = u64[] constant(9223372036854775807)
}
)"
},
});
@ -803,6 +843,32 @@ ENTRY ReduceR3ToR2.v3 {
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)"
},
// tuple reduce
{
"TupleReduce",
R"(HloModule TupleReduce
max_argmax {
value = f32[] parameter(2)
prev_max = f32[] parameter(0)
is_next_larger = pred[] greater-than-or-equal-to(value, prev_max)
max = f32[] select(is_next_larger, value, prev_max)
index = s32[] parameter(3)
prev_argmax = s32[] parameter(1)
argmax = s32[] select(is_next_larger, index, prev_argmax)
ROOT pair = (f32[], s32[]) tuple(max, argmax)
}
ENTRY reduce_entry {
values = f32[1024]{0} parameter(0)
indices = f32[1024]{0} parameter(1)
init_value = f32[] constant(-inf)
init_index = s32[] constant(-1)
ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
}
)"
},
// infeed/outfeed
@ -1224,6 +1290,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] {
"is out of range for literal's primitive type F16");
}
TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
const string original = R"(
HloModule ConstantUnsignedUnderflow_module
ENTRY %ConstantUnsignedUnderflow () -> u64[] {
ROOT %constant = u64[] constant(-1)
})";
auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"is out of range for literal's primitive type U64");
}
TEST_F(HloParserTest, ConstantUnsignedOverflow) {
const string original = R"(
HloModule ConstantUnsignedOverflow_module
ENTRY %ConstantUnsignedOverflow () -> u32[] {
ROOT %constant = u32[] constant(4294967296)
})";
auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"is out of range for literal's primitive type U32");
}
TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
const string original = R"(
HloModule ConstantUnsignedOverflow_module
ENTRY %ConstantUnsignedOverflow () -> u64[] {
ROOT %constant = u64[] constant(9223372036854775808)
})";
auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, ConstantWithExp) {
const string original = R"(HloModule ConstantWithExp_module

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
#include <algorithm>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -34,9 +36,19 @@ class HloPassFix : public Pass {
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
bool changed_this_iteration = true;
int64 iteration_count = 0;
int64 limit =
std::max(static_cast<int64>(1000), module->instruction_count());
while (changed_this_iteration) {
TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
changed |= changed_this_iteration;
++iteration_count;
if (iteration_count == limit) {
LOG(ERROR)
<< "Unexpectedly number of iterations in HLO passes ("
<< iteration_count
<< ")\nIf compilation hangs here, please file a bug with XLA.";
}
}
return changed;
}

View File

@ -282,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
ScheduleComputationsInModule(*module,
[&TUPLE_SIZE](const BufferValue& buffer) {
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(
buffer.shape(), TUPLE_SIZE);
},

View File

@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
if (IsTuple()) {
for (auto& tuple_element_sharding : tuple_elements()) {
auto unique_device = tuple_element_sharding.UniqueDevice();
if (unique_device.ok()) {
device_map[unique_device.ValueOrDie()] += 1;
if (unique_device) {
device_map[*unique_device] += 1;
}
}
element_count = tuple_elements().size();
} else {
auto unique_device = UniqueDevice();
if (unique_device.ok()) {
device_map[unique_device.ValueOrDie()] += 1;
if (unique_device) {
device_map[*unique_device] += 1;
}
}
if (count != nullptr) {
@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
return Tuple(ShapeTree<HloSharding>(shape, *this));
}
StatusOr<int64> HloSharding::UniqueDevice() const {
tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on empty tuple");
return tensorflow::gtl::nullopt;
}
std::vector<StatusOr<int64>> results;
std::transform(tuple_elements_.begin(), tuple_elements_.end(),
std::back_inserter(results),
[](const HloSharding& s) { return s.UniqueDevice(); });
if (std::all_of(results.begin(), results.end(),
[&](const StatusOr<int64>& s) {
return s.ok() && results[0].ok() &&
s.ValueOrDie() == results[0].ValueOrDie();
})) {
return results[0];
} else {
return tensorflow::errors::InvalidArgument(
"Tuple did not contain a unique device");
tensorflow::gtl::optional<int64> unique_device;
for (auto& tuple_sharding : tuple_elements_) {
auto device = tuple_sharding.UniqueDevice();
if (!device || (unique_device && *device != *unique_device)) {
return tensorflow::gtl::nullopt;
}
unique_device = device;
}
return unique_device;
}
if (!replicated_ && maximal_ && !IsTuple()) {
if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on sharding that executes on multiple devices");
return tensorflow::gtl::nullopt;
}
bool HloSharding::HasUniqueDevice() const {
if (IsTuple()) {
return UniqueDevice().status().ok();
} else {
return !IsReplicated() && IsTileMaximal();
}
int64 HloSharding::GetUniqueDevice() const {
auto device = UniqueDevice();
CHECK(device) << "Sharding does not have a unique device: " << *this;
return *device;
}
Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {

View File

@ -158,12 +158,17 @@ class HloSharding {
// REQUIRES: !IsTuple()
std::vector<int64> TileLimitForDevice(int64 device) const;
// Returns the single device this op operates on.
// REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
StatusOr<int64> UniqueDevice() const;
// Returns the single device this op operates on. If the sharding does not
// span a single device, the return value will be empty.
// In order for a sharding to span a single device, every leaf sharding must
// be maximal and not replicated, and the used device must match.
tensorflow::gtl::optional<int64> UniqueDevice() const;
// Retrieves the unique device or fails with a CHECK.
int64 GetUniqueDevice() const;
// Returns true if this op only uses a single device.
bool HasUniqueDevice() const;
bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
// Returns the ShapeTree containing the shardings for each element of this
// tuple, if IsTuple, or a ShapeTree with a single element containing this

View File

@ -51,7 +51,7 @@ TEST_F(HloShardingTest, Replicate) {
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
/*num_devices=*/2));
EXPECT_IS_NOT_OK(sharding.UniqueDevice());
EXPECT_FALSE(sharding.HasUniqueDevice());
}
TEST_F(HloShardingTest, DevicePlacement) {
@ -60,7 +60,7 @@ TEST_F(HloShardingTest, DevicePlacement) {
EXPECT_TRUE(sharding.IsTileMaximal());
EXPECT_FALSE(sharding.UsesDevice(0));
EXPECT_TRUE(sharding.UsesDevice(5));
EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
EXPECT_EQ(5, sharding.GetUniqueDevice());
HloSharding other = HloSharding::Replicate();
EXPECT_NE(other, sharding);
@ -123,7 +123,7 @@ TEST_F(HloShardingTest, Tile) {
EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
EXPECT_IS_NOT_OK(sharding.UniqueDevice());
EXPECT_FALSE(sharding.HasUniqueDevice());
}
}

View File

@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
}
};
string node_name;
if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
instruction->has_sharding() &&
instruction->sharding().HasUniqueDevice()) {
node_name = StrCat(
"dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
if (debug_options_.xla_hlo_tfgraph_device_scopes()) {
auto device = instruction->sharding_unique_device();
if (device) {
node_name = StrCat("dev", *device);
}
}
// If an instruction is fused, put it in the subgraph of the fusion;
// otherwise, put it in the computation subgraph.
@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
NodeDef* node_def = graph_def_.add_node();
node_def->set_name(GetNodeNameForInstruction(instruction));
node_def->set_op(GetOpDefName(instruction));
if (instruction->has_sharding() &&
instruction->sharding().HasUniqueDevice()) {
TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
node_def->set_device(GetDeviceName(device));
auto device = instruction->sharding_unique_device();
if (device) {
node_def->set_device(GetDeviceName(*device));
}
SetNodeAttrs(instruction, node_def);
if (instruction->opcode() == HloOpcode::kFusion) {

View File

@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out,
string InstructionValueSet::ToString() const {
string out =
StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
ForEachElement([this, &out](const ShapeIndex& index,
const HloValueSet& value_set) {
ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
});
return out;

View File

@ -224,10 +224,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
if (!ShapeUtil::IsArray(reduce->shape())) {
return InvalidArgument("Variadic reduce is not supported.");
}
return CheckShape(
reduce,
ShapeInference::InferReduceShape(
reduce->operand(0)->shape(), reduce->operand(1)->shape(),
{&reduce->operand(0)->shape(), &reduce->operand(1)->shape()},
reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
}
@ -510,6 +513,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
return CheckShape(
scatter, ShapeInference::InferScatterShape(
scatter->operand(0)->shape(), scatter->operand(1)->shape(),
scatter->operand(2)->shape(),
scatter->to_apply()->ComputeProgramShape(),
scatter->scatter_dimension_numbers()));
}
Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : token->operands()) {

Some files were not shown because too many files have changed in this diff Show More