Merge remote-tracking branch 'upstream/master'
commit 9523a98466
@@ -22,6 +22,8 @@ organization for the purposes of conducting machine learning and deep neural
networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.

TensorFlow provides stable Python and C APIs, as well as APIs without backwards compatibility guarantees for C++, Go, Java, JavaScript, and Swift.

Keep up to date with release announcements and security updates by
subscribing to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
@@ -81,13 +83,13 @@ The TensorFlow project strives to abide by generally accepted best practices in

| Build Type | Status | Artifacts |
| --- | --- | --- |
| **Linux CPU** |  | [pypi](https://pypi.org/project/tf-nightly/) |
| **Linux GPU** |  | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Linux XLA** |  | TBA |
| **MacOS** |  | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows CPU** |  | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** |  | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** |  | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
| **Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA |
| **MacOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |

### Community Supported Builds
@@ -97,17 +99,20 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
| **IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
| **Linux CPU with Intel® MKL-DNN®** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA |
| **Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6| |[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |

## For more information

* [Tensorflow Blog](https://medium.com/tensorflow)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [Tensorflow Twitter](https://twitter.com/tensorflow)
* [TensorFlow Website](https://www.tensorflow.org)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)

Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.

@@ -19,7 +19,7 @@
* `tf.data`:
  * `tf.contrib.data.group_by_reducer()` is now available via the public API.
  * `tf.contrib.data.choose_from_datasets()` is now available via the public API.
  * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
  * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating `tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`.
* `tf.estimator`:
  * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export.
  * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases.

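As context for the `drop_remainder` note above, here is a minimal, illustrative sketch of the new argument on `tf.data.Dataset.batch()`; it is not part of this commit, and the dataset values and session usage are hypothetical examples.

```python
import tensorflow as tf

# A dataset of 10 elements batched by 3: with drop_remainder=True the final
# partial batch (the lone element 9) is dropped, which is what the deprecated
# tf.contrib.data.batch_and_drop_remainder() transformation used to do.
dataset = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
next_batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(next_batch))  # [0 1 2]
    print(sess.run(next_batch))  # [3 4 5]
    print(sess.run(next_batch))  # [6 7 8]; no partial [9] batch follows
```
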
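Similarly, a hedged sketch of the `EstimatorSpec` change described above: a simple `model_fn` that omits `export_outputs` in PREDICT mode and relies on the default prediction output. The one-layer model, the feature key `"x"`, and the optimizer are hypothetical and only for illustration; only the PREDICT and TRAIN paths are handled here.

```python
import tensorflow as tf

def model_fn(features, labels, mode):
    # Hypothetical one-layer model, for illustration only.
    logits = tf.layers.dense(features["x"], 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Per the release note, no explicit export_output / PredictOutput is
        # required here; a default prediction output is added for export.
        return tf.estimator.EstimatorSpec(mode, predictions={"logits": logits})
    loss = tf.losses.mean_squared_error(labels, logits)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
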
@@ -451,11 +451,6 @@ filegroup(
    ),
)

filegroup(
    name = "docs_src",
    data = glob(["docs_src/**/*.md"]),
)

cc_library(
    name = "grpc",
    deps = select({
@@ -599,6 +594,7 @@ exports_files(
gen_api_init_files(
    name = "tensorflow_python_api_gen",
    srcs = ["api_template.__init__.py"],
    api_version = 1,
    root_init_template = "api_template.__init__.py",
)

@@ -1619,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  TF_DeleteFunction(func1);
}

// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
  TF_Function* func;
  const std::string funcName = "BranchFunc";
  DefineFunction(funcName.c_str(), &func);
  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* feed = Placeholder(host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_OperationDescription* desc =
      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
  TF_AddInput(desc, {true_cond, 0});
  TF_Output inputs[] = {{feed, 0}};
  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_SetAttrType(desc, "Tcond", TF_BOOL);
  TF_DataType inputType = TF_INT32;
  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
  TF_SetDevice(desc, "/device:XLA_CPU:0");
  auto op = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(op, nullptr);

  // Create a session for this graph.
  CSession csession(host_graph_, s_, /*use_XLA*/ true);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run the graph.
  csession.SetInputs({{feed, Int32Tensor(17)}});
  csession.SetOutputs({op});
  csession.Run(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(out));
  EXPECT_EQ(0, TF_NumDims(out));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
  EXPECT_EQ(-17, *output_contents);

  // Clean up
  csession.CloseAndDelete(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_DeleteFunction(func);
}
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace
}  // namespace tensorflow

@@ -26,6 +26,10 @@ limitations under the License.
using tensorflow::GraphDef;
using tensorflow::NodeDef;

static void BoolDeallocator(void* data, size_t, void* arg) {
  delete[] static_cast<bool*>(data);
}

static void Int32Deallocator(void* data, size_t, void* arg) {
  delete[] static_cast<int32_t*>(data);
}
@@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) {
  delete[] static_cast<float*>(data);
}

TF_Tensor* BoolTensor(bool v) {
  const int num_bytes = sizeof(bool);
  bool* values = new bool[1];
  values[0] = v;
  return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
                      nullptr);
}

TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
  int64_t num_values = 1;
  for (int i = 0; i < num_dims; ++i) {
@@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
  return op;
}

TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
                          const char* name) {
  unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
  return Const(tensor.get(), graph, s, name);
}

TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
                          const char* name) {
  unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);

@@ -31,6 +31,8 @@ using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
    unique_tensor_ptr;

TF_Tensor* BoolTensor(int32_t v);

// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);

@@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
                    const char* name = "const");

TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
                          const char* name = "scalar");

@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices(
|
||||
|
||||
tensorflow::Status CreateRemoteContexts(
|
||||
const std::vector<string>& remote_workers, int64 rendezvous_id,
|
||||
const tensorflow::ServerDef& server_def,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
|
||||
for (int i = 0; i < remote_workers.size(); i++) {
|
||||
@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts(
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
|
||||
if (eager_client == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
@ -151,7 +152,8 @@ tensorflow::Status CreateRemoteContexts(
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx) {
|
||||
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
|
||||
// server object (which currently CHECK-fails) and we miss the error, instead,
|
||||
// we log the error, and then return to allow the user to see the error
|
||||
@ -202,8 +204,8 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// Initialize remote eager workers.
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
|
||||
ctx->context.Async(), &remote_contexts));
|
||||
remote_workers, rendezvous_id, keep_alive_secs, server_def,
|
||||
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
|
||||
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
|
||||
@ -222,9 +224,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
|
||||
ctx->context.InitializeRemote(
|
||||
std::move(server), std::move(remote_eager_workers),
|
||||
std::move(remote_device_mgr), remote_contexts, r, device_mgr);
|
||||
ctx->context.InitializeRemote(std::move(server),
|
||||
std::move(remote_eager_workers),
|
||||
std::move(remote_device_mgr), remote_contexts,
|
||||
r, device_mgr, keep_alive_secs);
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
@ -288,6 +291,7 @@ void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
int keep_alive_secs,
|
||||
const void* proto,
|
||||
size_t proto_len,
|
||||
TF_Status* status) {
|
||||
@ -297,7 +301,8 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
|
||||
status->status =
|
||||
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
|
||||
}
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
@ -719,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
|
||||
|
||||
namespace tensorflow {
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
|
@ -124,6 +124,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
|
||||
// If the following is set, all servers identified by the
|
||||
// ServerDef must be up when the context is created.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
int keep_alive_secs,
|
||||
const void* proto,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Some TF ops need a step container to be set to limit the lifetime of some
|
||||
// resources (mostly TensorArray and Stack, used in while loop gradients in
|
||||
// graph mode). Calling this on a context tells it to start a step.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
|
||||
|
||||
// Ends a step. When there is no active step (that is, every started step has
|
||||
// been ended) step containers will be cleared. Note: it is not safe to call
|
||||
// TFE_ContextEndStep while ops which rely on the step container may be running.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -151,7 +151,7 @@ void TestRemoteExecute(bool async) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
@ -239,7 +239,7 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
@ -371,7 +371,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
@ -397,7 +397,7 @@ void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
|
@ -379,9 +379,11 @@ tf_cc_test(
|
||||
srcs = ["gradients/math_grad_test.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":client_session",
|
||||
":grad_op_registry",
|
||||
":grad_testutil",
|
||||
":gradient_checker",
|
||||
":gradients",
|
||||
":math_grad",
|
||||
":testutil",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Split", SplitGrad);
|
||||
|
||||
Status FillGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
// y = fill(fill_shape, x)
|
||||
// No gradient returned for the fill_shape argument.
|
||||
grad_outputs->push_back(NoGradient());
|
||||
// The gradient for x (which must be a scalar) is just the sum of
|
||||
// all the gradients from the shape it fills.
|
||||
// We use ReduceSum to implement this, which needs an argument providing
|
||||
// the indices of all the dimensions of the incoming gradient.
|
||||
// grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
|
||||
auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
|
||||
Const(scope, 1));
|
||||
grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Fill", FillGrad);
|
||||
|
||||
Status DiagGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
|
@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) {
|
||||
RunTest({x}, {x_shape}, y.output, {y_shape, y_shape});
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, FillGrad) {
|
||||
TensorShape x_shape({});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
TensorShape y_shape({2, 5, 3});
|
||||
auto y = Fill(scope_, {2, 5, 3}, x);
|
||||
RunTest(x, x_shape, y, y_shape);
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, DiagGrad) {
|
||||
TensorShape x_shape({5, 2});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
|
@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
|
||||
|
||||
Status UnsafeDivGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto x_1 = ConjugateHelper(scope, op.input(0));
|
||||
auto x_2 = ConjugateHelper(scope, op.input(1));
|
||||
// y = x_1 / x_2
|
||||
// dy/dx_1 = 1/x_2
|
||||
// dy/dx_2 = -x_1/x_2^2
|
||||
auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
|
||||
auto gx_2 =
|
||||
Mul(scope, grad_inputs[0],
|
||||
UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
|
||||
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
|
||||
}
|
||||
REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
|
||||
|
||||
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
@ -1007,6 +1023,26 @@ Status ProdGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Prod", ProdGrad);
|
||||
|
||||
Status SegmentSumGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
// The SegmentSum operation sums segments of the Tensor that have the same
|
||||
// index in the segment_ids parameter.
|
||||
// i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1]
|
||||
// will produce [2 + 3 + 4, 5] = [9, 5]
|
||||
// The gradient that will flow back to the gather operation will look like
|
||||
// [x1, x2], it will have the same shape as the output of the SegmentSum
|
||||
// operation. The differentiation step of the SegmentSum operation just
|
||||
// broadcast the gradient in order to retrieve the z's shape.
|
||||
// dy/dz = [x1, x1, x1, x2]
|
||||
grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));
|
||||
|
||||
// stop propagation along segment_ids
|
||||
grad_outputs->push_back(NoGradient());
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);
|
||||
|
||||
// MatMulGrad helper function used to compute two MatMul operations
|
||||
// based on input matrix transposition combinations.
|
||||
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
|
||||
|
@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradient_checker.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/testutil.h"
|
||||
#include "tensorflow/cc/gradients/grad_testutil.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
@ -42,9 +44,11 @@ using ops::Placeholder;
|
||||
using ops::Pow;
|
||||
using ops::Prod;
|
||||
using ops::RealDiv;
|
||||
using ops::SegmentSum;
|
||||
using ops::SquaredDifference;
|
||||
using ops::Sub;
|
||||
using ops::Sum;
|
||||
using ops::UnsafeDiv;
|
||||
|
||||
// TODO(andydavis) Test gradient function against numeric gradients output.
|
||||
// TODO(andydavis) As more gradients are added move common test functions
|
||||
@ -850,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) {
|
||||
RunTest({x}, {x_shape}, {y}, {x_shape});
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, UnsafeDiv) {
|
||||
{
|
||||
TensorShape x_shape({3, 2, 5});
|
||||
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
|
||||
// division errors in the numeric estimator used by the gradient checker.
|
||||
const auto y = UnsafeDiv(
|
||||
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
|
||||
RunTest({x}, {x_shape}, {y}, {x_shape});
|
||||
}
|
||||
{
|
||||
// Return 0 gradient (rather than NaN) for division by zero.
|
||||
const auto x = Placeholder(scope_, DT_FLOAT);
|
||||
const auto zero = Const<float>(scope_, 0.0);
|
||||
const auto y = UnsafeDiv(scope_, x, zero);
|
||||
|
||||
std::vector<Output> grad_outputs;
|
||||
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
|
||||
ClientSession session(scope_);
|
||||
std::vector<Tensor> grad_result;
|
||||
TF_EXPECT_OK(
|
||||
session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
|
||||
EXPECT_EQ(grad_result.size(), 1);
|
||||
EXPECT_EQ(grad_result[0].NumElements(), 3);
|
||||
EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
|
||||
EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
|
||||
EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, SquaredDifference) {
|
||||
TensorShape x1_shape({3, 2, 5});
|
||||
TensorShape x2_shape({2, 5});
|
||||
@ -898,5 +932,14 @@ TEST_F(NaryGradTest, Prod) {
|
||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, SegmentSum) {
|
||||
TensorShape x_shape({3, 4});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
auto y = SegmentSum(scope_, x, {0, 0, 1});
|
||||
// the sum is always on the first dimension
|
||||
TensorShape y_shape({2, 4});
|
||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -170,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
variables_directory, MetaFilename(kSavedModelVariablesFilename));
|
||||
if (!Env::Default()->FileExists(variables_index_path).ok()) {
|
||||
LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
|
||||
"were restored.";
|
||||
"were restored. File does not exist: "
|
||||
<< variables_index_path;
|
||||
return Status::OK();
|
||||
}
|
||||
const string variables_path =
|
||||
|
@ -48,6 +48,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/str_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -36,6 +37,8 @@ namespace tfcompile {
|
||||
|
||||
namespace {
|
||||
|
||||
using BufferInfo = cpu_function_runtime::BufferInfo;
|
||||
|
||||
bool IsAlpha(char c) {
|
||||
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
|
||||
}
|
||||
@ -85,27 +88,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
|
||||
// values. There are `n` entries in `sizes`.
|
||||
size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
|
||||
size_t total = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if (sizes[i] != -1) {
|
||||
total += sizes[i];
|
||||
}
|
||||
}
|
||||
return total;
|
||||
// Returns the sum of the size of each buffer in `buffer_infos`.
|
||||
size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
|
||||
return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
|
||||
[](size_t size, const BufferInfo& buffer_info) {
|
||||
return size + buffer_info.size();
|
||||
});
|
||||
}
|
||||
|
||||
// Fills in arg_sizes with the byte size of each positional arg.
|
||||
Status ComputeArgSizes(const CompileResult& compile_result,
|
||||
std::vector<int64>* arg_sizes) {
|
||||
const xla::ProgramShape& ps = compile_result.program_shape;
|
||||
for (int i = 0; i < ps.parameters_size(); ++i) {
|
||||
arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
|
||||
ps.parameters(i), compile_result.pointer_size));
|
||||
// Returns a vector of BufferInfo instances in `buffer_infos` that are entry
|
||||
// parameter buffers.
|
||||
std::vector<BufferInfo> ExtractEntryParamBufferInfos(
|
||||
const std::vector<BufferInfo>& buffer_infos) {
|
||||
std::vector<BufferInfo> result;
|
||||
std::copy_if(buffer_infos.begin(), buffer_infos.end(),
|
||||
std::back_inserter(result), [](const BufferInfo& buffer_info) {
|
||||
return buffer_info.is_entry_parameter();
|
||||
});
|
||||
return result;
|
||||
}
|
||||
return Status::OK();
|
||||
|
||||
// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
|
||||
// buffers.
|
||||
std::vector<BufferInfo> ExtractTempBufferInfos(
|
||||
const std::vector<BufferInfo>& buffer_infos) {
|
||||
std::vector<BufferInfo> result;
|
||||
std::copy_if(buffer_infos.begin(), buffer_infos.end(),
|
||||
std::back_inserter(result), [](const BufferInfo& buffer_info) {
|
||||
return buffer_info.is_temp_buffer();
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
|
||||
@ -278,6 +290,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns a list of C++ expressions that, when executed, will construct the
|
||||
// BufferInfo instances in `buffer_infos`.
|
||||
std::vector<string> BufferInfosToCppExpression(
|
||||
const std::vector<BufferInfo>& buffer_infos) {
|
||||
std::vector<string> buffer_infos_as_strings;
|
||||
std::transform(buffer_infos.begin(), buffer_infos.end(),
|
||||
std::back_inserter(buffer_infos_as_strings),
|
||||
[](const BufferInfo& buffer_info) {
|
||||
std::pair<uint64, uint64> encoded = buffer_info.Encode();
|
||||
string encoded_second_as_str =
|
||||
encoded.second == ~0ULL
|
||||
? "~0ULL"
|
||||
: strings::StrCat(encoded.second, "ULL");
|
||||
return strings::StrCat(
|
||||
"::tensorflow::cpu_function_runtime::BufferInfo({",
|
||||
encoded.first, "ULL, ", encoded_second_as_str, "})");
|
||||
});
|
||||
return buffer_infos_as_strings;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
@ -286,29 +317,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
|
||||
const int64 result_index = compile_result.aot->result_buffer_index();
|
||||
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
|
||||
if (result_index < 0 || result_index >= temp_sizes.size()) {
|
||||
const std::vector<BufferInfo>& buffer_infos =
|
||||
compile_result.aot->buffer_infos();
|
||||
const std::vector<int32> arg_index_table =
|
||||
::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
|
||||
std::vector<string> buffer_infos_as_strings =
|
||||
BufferInfosToCppExpression(buffer_infos);
|
||||
if (result_index < 0 || result_index >= buffer_infos.size()) {
|
||||
return errors::InvalidArgument("result index: ", result_index,
|
||||
" is outside the range of temp sizes: [0,",
|
||||
temp_sizes.size(), ")");
|
||||
buffer_infos.size(), ")");
|
||||
}
|
||||
|
||||
// Compute sizes and generate methods.
|
||||
std::vector<int64> arg_sizes;
|
||||
TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
|
||||
std::vector<BufferInfo> buffer_infos_for_args =
|
||||
ExtractEntryParamBufferInfos(buffer_infos);
|
||||
std::vector<BufferInfo> buffer_infos_for_temps =
|
||||
ExtractTempBufferInfos(buffer_infos);
|
||||
const xla::ProgramShape& ps = compile_result.program_shape;
|
||||
string methods_arg, methods_result;
|
||||
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
|
||||
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
|
||||
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
|
||||
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
|
||||
const size_t arg_bytes_aligned =
|
||||
cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size());
|
||||
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
|
||||
const size_t temp_bytes_aligned =
|
||||
cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size());
|
||||
const size_t temp_bytes_total =
|
||||
total_buffer_bytes(itemp.data(), itemp.size());
|
||||
const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
|
||||
buffer_infos_for_args.data(), buffer_infos_for_args.size(),
|
||||
/*allocate_entry_params=*/true);
|
||||
const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
|
||||
const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
|
||||
buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
|
||||
/*allocate_entry_params=*/true);
|
||||
const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);
|
||||
|
||||
// Create rewrite strings for namespace start and end.
|
||||
string ns_start;
|
||||
@ -343,8 +380,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
// calling HloProfilePrinter::profile_counters_size.
|
||||
const string assign_profile_counters_size =
|
||||
opts.gen_hlo_profile_printer_data
|
||||
? "data->profile_counters_size = "
|
||||
"data->hlo_profile_printer_data->profile_counters_size();"
|
||||
? "data->set_profile_counters_size("
|
||||
"data->hlo_profile_printer_data()->profile_counters_size());"
|
||||
: "";
|
||||
|
||||
// Use a poor-man's text templating mechanism; first populate the full header
|
||||
@ -414,9 +451,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const intptr_t* ArgSizes() {
|
||||
static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
|
||||
return kArgSizes;
|
||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
||||
}
|
||||
|
||||
// Returns static data used to create an XlaCompiledCpuFunction.
|
||||
@ -424,16 +460,16 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
|
||||
XlaCompiledCpuFunction::StaticData* data =
|
||||
new XlaCompiledCpuFunction::StaticData;
|
||||
data->raw_function = {{ENTRY}};
|
||||
data->arg_sizes = ArgSizes();
|
||||
data->num_args = kNumArgs;
|
||||
data->temp_sizes = TempSizes();
|
||||
data->num_temps = kNumTemps;
|
||||
data->result_index = kResultIndex;
|
||||
data->arg_names = StaticArgNames();
|
||||
data->result_names = StaticResultNames();
|
||||
data->program_shape = StaticProgramShape();
|
||||
data->hlo_profile_printer_data = StaticHloProfilePrinterData();
|
||||
data->set_raw_function({{ENTRY}});
|
||||
data->set_buffer_infos(BufferInfos());
|
||||
data->set_num_buffers(kNumBuffers);
|
||||
data->set_arg_index_table(ArgIndexToBufferIndex());
|
||||
data->set_num_args(kNumArgs);
|
||||
data->set_result_index(kResultIndex);
|
||||
data->set_arg_names(StaticArgNames());
|
||||
data->set_result_names(StaticResultNames());
|
||||
data->set_program_shape(StaticProgramShape());
|
||||
data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
|
||||
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
|
||||
return data;
|
||||
}();
|
||||
@ -482,17 +518,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
{{METHODS_RESULT}}
|
||||
|
||||
private:
|
||||
// Number of result and temporary buffers for the compiled computation.
|
||||
static constexpr size_t kNumTemps = {{TEMP_NUM}};
|
||||
// Number of buffers for the compiled computation.
|
||||
static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};
|
||||
|
||||
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||
static const ::tensorflow::cpu_function_runtime::BufferInfo
|
||||
kBufferInfos[kNumBuffers] = {
|
||||
{{BUFFER_INFOS_AS_STRING}}
|
||||
};
|
||||
return kBufferInfos;
|
||||
}
|
||||
|
||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||
{{ARG_INDEX_TABLE}}
|
||||
};
|
||||
return kArgIndexToBufferIndex;
|
||||
}
|
||||
|
||||
// The 0-based index of the result tuple in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
|
||||
|
||||
// Byte size of each result / temporary buffer. There are kNumTemps entries.
|
||||
static const intptr_t* TempSizes() {
|
||||
static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
|
||||
return kTempSizes;
|
||||
}
|
||||
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
||||
|
||||
@ -523,8 +569,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
|
||||
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
|
||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
|
||||
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
|
||||
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
|
||||
{"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
|
||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||
{"{{CLASS}}", opts.class_name},
|
||||
{"{{DECLS_FROM_OBJ_FILE}}",
|
||||
@ -546,8 +592,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
||||
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
|
||||
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
|
||||
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
|
||||
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
|
||||
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
|
||||
{"{{BUFFER_INFOS_AS_STRING}}",
|
||||
str_util::Join(buffer_infos_as_strings, ",\n")}};
|
||||
str_util::ReplaceAllPairs(header, rewrites);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -32,6 +32,8 @@ namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::cpu_function_runtime::BufferInfo;
|
||||
|
||||
void ExpectErrorContains(const Status& status, StringPiece str) {
|
||||
EXPECT_NE(Status::OK(), status);
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
|
||||
@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) {
|
||||
fetch->mutable_id()->set_node_name("fetch0");
|
||||
fetch->set_name("myfetch");
|
||||
CompileResult compile_result;
|
||||
compile_result.aot.reset(
|
||||
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
|
||||
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
|
||||
{},
|
||||
{BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
||||
BufferInfo::MakeTempBuffer(2),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
||||
5, {}));
|
||||
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
|
||||
{
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
|
||||
|
@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
||||
static constexpr size_t kNumArgs = 2;
|
||||
|
||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||
static const intptr_t* ArgSizes() {
|
||||
static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96};
|
||||
return kArgSizes;
|
||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
||||
}
|
||||
|
||||
// Returns static data used to create an XlaCompiledCpuFunction.
|
||||
@ -75,16 +74,16 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
||||
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
|
||||
XlaCompiledCpuFunction::StaticData* data =
|
||||
new XlaCompiledCpuFunction::StaticData;
|
||||
data->raw_function = entry_point;
|
||||
data->arg_sizes = ArgSizes();
|
||||
data->num_args = kNumArgs;
|
||||
data->temp_sizes = TempSizes();
|
||||
data->num_temps = kNumTemps;
|
||||
data->result_index = kResultIndex;
|
||||
data->arg_names = StaticArgNames();
|
||||
data->result_names = StaticResultNames();
|
||||
data->program_shape = StaticProgramShape();
|
||||
data->hlo_profile_printer_data = StaticHloProfilePrinterData();
|
||||
data->set_raw_function(entry_point);
|
||||
data->set_buffer_infos(BufferInfos());
|
||||
data->set_num_buffers(kNumBuffers);
|
||||
data->set_arg_index_table(ArgIndexToBufferIndex());
|
||||
data->set_num_args(kNumArgs);
|
||||
data->set_result_index(kResultIndex);
|
||||
data->set_arg_names(StaticArgNames());
|
||||
data->set_result_names(StaticResultNames());
|
||||
data->set_program_shape(StaticProgramShape());
|
||||
data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
|
||||
|
||||
return data;
|
||||
}();
|
||||
@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
||||
}
|
||||
|
||||
private:
|
||||
// Number of result and temporary buffers for the compiled computation.
|
||||
static constexpr size_t kNumTemps = 6;
|
||||
// Number of buffers for the compiled computation.
|
||||
static constexpr size_t kNumBuffers = 6;
|
||||
|
||||
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||
static const ::tensorflow::cpu_function_runtime::BufferInfo
|
||||
kBufferInfos[kNumBuffers] = {
|
||||
::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
||||
::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
|
||||
::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
||||
::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
|
||||
::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
||||
};
|
||||
return kBufferInfos;
|
||||
}
|
||||
|
||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||
1, 3
|
||||
};
|
||||
return kArgIndexToBufferIndex;
|
||||
}
|
||||
|
||||
// The 0-based index of the result tuple in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = 5;
|
||||
|
||||
// Byte size of each result / temporary buffer. There are kNumTemps entries.
|
||||
static const intptr_t* TempSizes() {
|
||||
static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
|
||||
return kTempSizes;
|
||||
}
|
||||
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {
|
||||
static const char* kNames[] = {"myfeed", nullptr};
|
||||
|
@ -51,11 +51,9 @@ namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
namespace {
|
||||
|
||||
void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (sizes[i] != -1) {
|
||||
memset(bufs[i], 0, sizes[i]);
|
||||
}
|
||||
void zero_buffers(XlaCompiledCpuFunction* computation) {
|
||||
for (int i = 0; i < computation->num_args(); ++i) {
|
||||
memset(computation->arg_data(i), 0, computation->arg_size(i));
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
|
||||
|
||||
CPP_CLASS computation;
|
||||
computation.set_thread_pool(&device);
|
||||
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
|
||||
zero_buffers(&computation);
|
||||
|
||||
EXPECT_TRUE(computation.Run());
|
||||
}
|
||||
@ -80,7 +78,7 @@ void BM_NAME(int iters) {
|
||||
|
||||
CPP_CLASS computation;
|
||||
computation.set_thread_pool(&device);
|
||||
zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
|
||||
zero_buffers(&computation);
|
||||
|
||||
testing::StartTiming();
|
||||
while (--iters) {
|
||||
|
@ -44,8 +44,8 @@ using ::testing::IsSupersetOf;
|
||||
|
||||
TEST(TFCompileTest, Add) {
|
||||
AddComp add;
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add.arg1_data(), add.args()[1]);
|
||||
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
|
||||
EXPECT_EQ(add.arg1_data(), add.arg_data(1));
|
||||
|
||||
add.arg0() = 1;
|
||||
add.arg1() = 2;
|
||||
@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) {
|
||||
EXPECT_EQ(add_const.error_msg(), "");
|
||||
EXPECT_EQ(add_const.arg0(), 123);
|
||||
EXPECT_EQ(add_const.arg0_data()[0], 123);
|
||||
EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add_const.arg0_data(), add.arg_data(0));
|
||||
EXPECT_EQ(add_const.arg1(), 456);
|
||||
EXPECT_EQ(add_const.arg1_data()[0], 456);
|
||||
EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
|
||||
EXPECT_EQ(add_const.arg1_data(), add.arg_data(1));
|
||||
EXPECT_EQ(add_const.result0(), 579);
|
||||
EXPECT_EQ(add_const.result0_data()[0], 579);
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) {
|
||||
int32 arg_y = 32;
|
||||
add.set_arg0_data(&arg_x);
|
||||
add.set_arg1_data(&arg_y);
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add.arg1_data(), add.args()[1]);
|
||||
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
|
||||
EXPECT_EQ(add.arg1_data(), add.arg_data(1));
|
||||
|
||||
EXPECT_TRUE(add.Run());
|
||||
EXPECT_EQ(add.error_msg(), "");
|
||||
@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) {
|
||||
|
||||
TEST(TFCompileTest, AddWithCkpt) {
|
||||
AddWithCkptComp add;
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
|
||||
|
||||
add.arg0() = 1;
|
||||
EXPECT_TRUE(add.Run());
|
||||
@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) {
|
||||
EXPECT_EQ(add_const.error_msg(), "");
|
||||
EXPECT_EQ(add_const.arg0(), 111);
|
||||
EXPECT_EQ(add_const.arg0_data()[0], 111);
|
||||
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
|
||||
EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
|
||||
EXPECT_EQ(add_const.result0(), 153);
|
||||
EXPECT_EQ(add_const.result0_data()[0], 153);
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) {
|
||||
|
||||
TEST(TFCompileTest, AddWithCkptSaver) {
|
||||
AddWithCkptSaverComp add;
|
||||
EXPECT_EQ(add.arg0_data(), add.args()[0]);
|
||||
EXPECT_EQ(add.arg0_data(), add.arg_data(0));
|
||||
|
||||
add.arg0() = 1;
|
||||
EXPECT_TRUE(add.Run());
|
||||
@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) {
|
||||
EXPECT_EQ(add_const.error_msg(), "");
|
||||
EXPECT_EQ(add_const.arg0(), 111);
|
||||
EXPECT_EQ(add_const.arg0_data()[0], 111);
|
||||
EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
|
||||
EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
|
||||
EXPECT_EQ(add_const.result0(), 153);
|
||||
EXPECT_EQ(add_const.result0_data()[0], 153);
|
||||
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
|
||||
@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) {
|
||||
|
||||
TEST(TFCompileTest, Cond) {
|
||||
CondComp cond;
|
||||
EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
|
||||
EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
|
||||
EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
|
||||
EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
|
||||
EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
|
||||
EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
|
||||
cond.arg1() = 10;
|
||||
cond.arg2() = 20;
|
||||
{
|
||||
@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) {
|
||||
|
||||
TEST(TFCompileTest, Gather) {
|
||||
GatherComp gather;
|
||||
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
|
||||
EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
|
||||
EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
|
||||
EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));
|
||||
|
||||
// Successful gather.
|
||||
{
|
||||
@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) {
|
||||
EXPECT_EQ(gather_const.arg0(i), params[i]);
|
||||
EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
|
||||
}
|
||||
EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
|
||||
EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(gather_const.arg1(i), indices[i]);
|
||||
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
|
||||
}
|
||||
EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
|
||||
EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(gather_const.result0(i), results[i]);
|
||||
EXPECT_EQ(gather_const.result0_data()[i], results[i]);
|
||||
@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) {
|
||||
|
||||
foo::bar::MatMulComp matmul;
|
||||
matmul.set_thread_pool(&device);
|
||||
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
|
||||
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
|
||||
EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
|
||||
EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
|
||||
|
||||
// Test using the argN() methods.
|
||||
{
|
||||
@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) {
|
||||
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
|
||||
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
|
||||
}
|
||||
EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
|
||||
EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0));
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
|
||||
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
|
||||
}
|
||||
EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
|
||||
EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
|
||||
EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
|
||||
@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) {
|
||||
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
|
||||
matmul.set_arg0_data(&arg0);
|
||||
matmul.set_arg1_data(&arg1);
|
||||
EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
|
||||
EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
|
||||
EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
|
||||
EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
|
||||
|
||||
EXPECT_TRUE(matmul.Run());
|
||||
EXPECT_EQ(matmul.error_msg(), "");
|
||||
@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
|
||||
MatMulAndAddComp muladd;
|
||||
muladd.set_thread_pool(&device);
|
||||
EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
|
||||
EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
|
||||
EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
|
||||
EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));
|
||||
|
||||
// Test methods with positional args and results.
|
||||
{
|
||||
@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
|
||||
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
|
||||
EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
|
||||
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
|
||||
EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
|
||||
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
|
||||
@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
|
||||
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
|
||||
EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
|
||||
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
|
||||
}
|
||||
EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
|
||||
EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
|
||||
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
|
||||
@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
TEST(TFCompileTest, Function) {
|
||||
// The function is equivalent to an addition
|
||||
FunctionComp add_fn;
|
||||
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
|
||||
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
|
||||
EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0));
|
||||
EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1));
|
||||
|
||||
add_fn.arg0() = 1;
|
||||
add_fn.arg1() = 2;
|
||||
@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
||||
// Assert is converted into a no-op in XLA, so there is no failure even if the
|
||||
// two args are different.
|
||||
AssertComp assert;
|
||||
EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
|
||||
EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
|
||||
EXPECT_EQ(assert.arg0_data(), assert.arg_data(0));
|
||||
EXPECT_EQ(assert.arg1_data(), assert.arg_data(1));
|
||||
|
||||
assert.arg0() = 2;
|
||||
assert.arg1() = 1;
|
||||
|
@ -160,6 +160,7 @@ cc_library(
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -178,6 +179,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:fifo_queue",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:identity_n_op",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:no_op",
|
||||
@ -186,6 +188,9 @@ cc_library(
|
||||
"//tensorflow/core/kernels:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -46,6 +46,7 @@ class Predicate {
|
||||
|
||||
virtual string ToString() const = 0;
|
||||
int64 hash() const { return hash_; }
|
||||
virtual gtl::ArraySlice<Predicate*> GetOperands() const = 0;
|
||||
|
||||
virtual Kind kind() const = 0;
|
||||
virtual ~Predicate() {}
|
||||
@ -90,7 +91,8 @@ class AndPredicate : public Predicate {
|
||||
|
||||
Kind kind() const override { return Kind::kAnd; }
|
||||
|
||||
const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
|
||||
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
|
||||
gtl::ArraySlice<Predicate*> operands() const { return operands_; }
|
||||
|
||||
private:
|
||||
std::vector<Predicate*> operands_;
|
||||
@ -117,7 +119,8 @@ class OrPredicate : public Predicate {
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kOr; }
|
||||
const gtl::ArraySlice<Predicate*> operands() const { return operands_; }
|
||||
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
|
||||
gtl::ArraySlice<Predicate*> operands() const { return operands_; }
|
||||
|
||||
private:
|
||||
std::vector<Predicate*> operands_;
|
||||
@ -128,17 +131,18 @@ class NotPredicate : public Predicate {
public:
explicit NotPredicate(Predicate* operand)
: Predicate(HashPredicateSequence(Kind::kNot, {operand})),
operand_(operand) {}
operands_({operand}) {}

string ToString() const override {
return strings::StrCat("~", operand()->ToString());
}

Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operand_; }
Predicate* operand() const { return operands_[0]; }
gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }

private:
Predicate* operand_;
std::array<Predicate*, 1> operands_;
};

// Represents an uninterpreted symbol in a logical predicate.
|
||||
@ -158,6 +162,7 @@ class SymbolPredicate : public Predicate {
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kSymbol; }
|
||||
gtl::ArraySlice<Predicate*> GetOperands() const override { return {}; }
|
||||
|
||||
// If `must_be_true()` is true this SymbolPredicate represents the proposition
|
||||
// "tensor_id() is live and evaluates to true".
|
||||
@ -288,10 +293,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,

if (op->kind() == pred_kind) {
// "Inline" the operands of an inner And/Or into the parent And/Or.
gtl::ArraySlice<Predicate*> operands =
is_and ? dynamic_cast<AndPredicate*>(op)->operands()
: dynamic_cast<OrPredicate*>(op)->operands();
for (Predicate* subop : operands) {
for (Predicate* subop : op->GetOperands()) {
if (simplified_ops_set.insert(subop).second) {
simplified_ops.push_back(subop);
}

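The hunk above is the payoff of adding GetOperands() to the Predicate base class: the And/Or flattening no longer needs a per-subclass dynamic_cast. A minimal, self-contained sketch of that pattern follows; the types are illustrative stand-ins, not the actual TensorFlow classes.

    #include <unordered_set>
    #include <utility>
    #include <vector>

    struct Pred {
      enum class Kind { kAnd, kOr, kSymbol };
      virtual ~Pred() = default;
      virtual Kind kind() const = 0;
      // Every node exposes its operands, so callers never downcast.
      virtual std::vector<Pred*> GetOperands() const = 0;
    };

    struct And : Pred {
      explicit And(std::vector<Pred*> ops) : operands(std::move(ops)) {}
      Kind kind() const override { return Kind::kAnd; }
      std::vector<Pred*> GetOperands() const override { return operands; }
      std::vector<Pred*> operands;
    };

    // Inlines operands of nested nodes of the same kind into the parent,
    // dropping duplicates, much like the simplified_ops loop above.
    std::vector<Pred*> Flatten(Pred::Kind kind, const std::vector<Pred*>& ops) {
      std::vector<Pred*> out;
      std::unordered_set<Pred*> seen;
      for (Pred* op : ops) {
        if (op->kind() == kind) {
          for (Pred* sub : op->GetOperands()) {
            if (seen.insert(sub).second) out.push_back(sub);
          }
        } else if (seen.insert(op).second) {
          out.push_back(op);
        }
      }
      return out;
    }

Keeping the accessor virtual lets AndPredicate and OrPredicate share one code path, which is what the single op->GetOperands() call in the hunk relies on.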
@ -1161,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
strings::StrCat("replace_encapsulate_fdef_", name), fdef);
}

TF_RETURN_IF_ERROR(library->RemoveFunction(name));
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
return Status::OK();
}

@ -16,6 +16,7 @@ cc_library(
|
||||
"//tensorflow/compiler/jit:xla_device",
|
||||
"//tensorflow/compiler/jit:xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
@ -199,7 +200,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_stream(stream);
run_options.set_allocator(xla_allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id());
run_options.set_rng_seed(GetXLARandomSeed());
Env* env = Env::Default();
auto start_time = env->NowMicros();

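The second change in this hunk seeds the XLA RNG from GetXLARandomSeed() instead of the step id, so repeated launches do not reuse the same seed. The helper's implementation is not part of this diff; purely as an illustration of the idea, a seed source with that property could look like the following (hypothetical code, not the TensorFlow implementation):

    #include <atomic>
    #include <cstdint>
    #include <random>

    // Hypothetical seed source: a process-random base value mixed with a
    // per-call counter, so every launch in the process gets a fresh seed.
    uint32_t NextRandomSeed() {
      static const uint32_t base = std::random_device{}();
      static std::atomic<uint32_t> counter{0};
      return base ^ (counter.fetch_add(1) * 2654435761u);  // odd multiplier mixes bits
    }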
@ -296,7 +296,7 @@ Status XlaCompilationCache::CompileImpl(
// protect the contents of the cache entry.
Entry* entry;
{
mutex_lock lock(mu_);
mutex_lock lock(compile_cache_mu_);
// Find or create a cache entry.
std::unique_ptr<Entry>& e = cache_[signature];
if (!e) {
@ -312,6 +312,8 @@ Status XlaCompilationCache::CompileImpl(
if (!entry->compiled) {
VLOG(1) << "Compilation cache miss for signature: "
<< SignatureDebugString(signature);
tensorflow::Env* env = tensorflow::Env::Default();
const uint64 compile_start_us = env->NowMicros();
// Do the actual JIT compilation without holding the lock (it can take
// a long time.)
std::vector<XlaCompiler::Argument> args;
@ -334,6 +336,26 @@ Status XlaCompilationCache::CompileImpl(
CHECK_EQ(entry->executable.get(), nullptr);
entry->compilation_status =
BuildExecutable(options, entry->compilation_result, &entry->executable);

const uint64 compile_end_us = env->NowMicros();
const uint64 compile_time_us = compile_end_us - compile_start_us;
{
mutex_lock lock(compile_stats_mu_);
auto it = compile_stats_.emplace(function.name(), CompileStats{}).first;
it->second.compile_count++;
it->second.cumulative_compile_time_us += compile_time_us;
VLOG(1) << "compiled " << function.name() << " "
<< it->second.compile_count
<< " times, compile time: " << compile_time_us
<< " us, cumulative: " << it->second.cumulative_compile_time_us
<< " us ("
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
1.0e6)
<< " / "
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
}
}
TF_RETURN_IF_ERROR(entry->compilation_status);
*compilation_result = &entry->compilation_result;

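The bookkeeping added above keeps per-function compile counts and cumulative compile time behind a dedicated mutex, so logging statistics never lengthens the critical section that protects the compilation cache itself. A standalone sketch of that bookkeeping, using standard-library containers in place of the TensorFlow types:

    #include <cstdint>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    struct CompileStats {
      int64_t compile_count = 0;
      int64_t cumulative_compile_time_us = 0;
    };

    class CompileStatsTracker {
     public:
      // Records one compilation of `name` that took `compile_time_us`.
      CompileStats Record(const std::string& name, int64_t compile_time_us) {
        std::lock_guard<std::mutex> lock(mu_);
        CompileStats& s = stats_[name];
        s.compile_count++;
        s.cumulative_compile_time_us += compile_time_us;
        return s;  // Copy out so the caller can log outside the lock.
      }

     private:
      std::mutex mu_;
      std::unordered_map<std::string, CompileStats> stats_;
    };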
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase {
std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
};

mutex mu_;
std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(mu_);
mutex compile_cache_mu_;
gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);

struct CompileStats {
// Number of times the cluster has been (re-)compiled.
int64 compile_count = 0;

// Cumulative time spent compiling the cluster.
int64 cumulative_compile_time_us = 0;
};
mutex compile_stats_mu_;

// Maps cluster names to compilation statistics for said cluster.
gtl::FlatMap<string, CompileStats> compile_stats_
GUARDED_BY(compile_stats_mu_);

TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};

@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
|
||||
@ -71,7 +72,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
run_options.set_stream(stream);
|
||||
run_options.set_allocator(client->backend().memory_allocator());
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(ctx->step_id());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
|
||||
xla::StatusOr<xla::ScopedShapedBuffer> run_result =
|
||||
executable->Run(launch_context.arguments(), run_options);
|
||||
|
@ -211,17 +211,18 @@ XlaDevice::XlaDevice(
|
||||
use_multiple_streams),
|
||||
device_ordinal_(device_ordinal),
|
||||
jit_device_name_(jit_device_name),
|
||||
xla_allocator_(nullptr),
|
||||
platform_(platform),
|
||||
use_multiple_streams_(use_multiple_streams),
|
||||
transfer_as_literal_(transfer_as_literal),
|
||||
shape_representation_fn_(shape_representation_fn) {
|
||||
VLOG(1) << "Created XLA device " << jit_device_name;
|
||||
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
|
||||
}
|
||||
|
||||
XlaDevice::~XlaDevice() {
|
||||
if (gpu_device_info_ != nullptr) {
|
||||
gpu_device_info_->default_context->Unref();
|
||||
VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
|
||||
mutex_lock lock(mu_);
|
||||
if (device_context_) {
|
||||
device_context_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
@ -237,6 +238,11 @@ xla::LocalClient* XlaDevice::client() const {
|
||||
}
|
||||
|
||||
Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
|
||||
mutex_lock lock(mu_);
|
||||
return GetAllocatorLocked(attr);
|
||||
}
|
||||
|
||||
Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
|
||||
if (attr.on_host()) {
|
||||
return cpu_allocator();
|
||||
}
|
||||
@ -249,83 +255,105 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
|
||||
return xla_allocator_;
|
||||
}
|
||||
|
||||
xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
|
||||
if (!stream_) {
|
||||
xla::Backend* backend = client()->mutable_backend();
|
||||
TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_));
|
||||
}
|
||||
return stream_.get();
|
||||
Status XlaDevice::EnsureDeviceContextOk() {
|
||||
mutex_lock lock(mu_);
|
||||
return GetDeviceContextLocked().status();
|
||||
}
|
||||
|
||||
xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
|
||||
if (!use_multiple_streams_) {
|
||||
return GetStream();
|
||||
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
|
||||
const string& name,
|
||||
xla::StreamPool::Ptr* stream,
|
||||
bool* stream_was_changed) {
|
||||
if (!(*stream) || !(*stream)->ok()) {
|
||||
TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
|
||||
VLOG(1) << "XlaDevice " << this << " new " << name << " "
|
||||
<< (*stream)->DebugStreamPointers();
|
||||
*stream_was_changed = true;
|
||||
}
|
||||
if (!device_to_host_stream_) {
|
||||
xla::Backend* backend = client()->mutable_backend();
|
||||
TF_ASSIGN_OR_RETURN(device_to_host_stream_,
|
||||
backend->BorrowStream(device_ordinal_));
|
||||
}
|
||||
return device_to_host_stream_.get();
|
||||
}
|
||||
|
||||
xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
|
||||
if (!use_multiple_streams_) {
|
||||
return GetStream();
|
||||
}
|
||||
if (!host_to_device_stream_) {
|
||||
xla::Backend* backend = client()->mutable_backend();
|
||||
TF_ASSIGN_OR_RETURN(host_to_device_stream_,
|
||||
backend->BorrowStream(device_ordinal_));
|
||||
}
|
||||
return host_to_device_stream_.get();
|
||||
}
|
||||
|
||||
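The new EnsureStreamOkLocked above replaces the lazy GetStream()/Get*Stream() getters: whenever a stream is missing or has entered an error state it is re-borrowed from the backend, and the caller learns whether anything changed so it can rebuild state (such as the device context) that captured the old stream. A schematic version of that pattern, with placeholder types standing in for xla::Backend and xla::StreamPool::Ptr:

    #include <memory>

    struct FakeStream { bool ok = true; };

    struct FakeStreamPool {
      std::unique_ptr<FakeStream> Borrow() { return std::make_unique<FakeStream>(); }
    };

    // Re-borrows `stream` if it is null or broken; reports whether it changed so
    // the caller knows to recreate anything that referenced the old stream.
    void EnsureStreamOk(FakeStreamPool* pool, std::unique_ptr<FakeStream>* stream,
                        bool* stream_was_changed) {
      if (!*stream || !(*stream)->ok) {
        *stream = pool->Borrow();
        *stream_was_changed = true;
      }
    }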
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
|
||||
if (gpu_device_info_ == nullptr) {
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
|
||||
// Call GetAllocator for the side-effect of ensuring the allocator
|
||||
// is created.
|
||||
GetAllocator({});
|
||||
// XlaDevice owns both gpu_device_info_ and
|
||||
// gpu_device_info_->default_context.
|
||||
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
|
||||
gpu_device_info_->stream = stream;
|
||||
gpu_device_info_->default_context =
|
||||
new XlaDeviceContext(stream, stream, stream, client(),
|
||||
transfer_as_literal_, shape_representation_fn_);
|
||||
set_tensorflow_gpu_device_info(gpu_device_info_.get());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
|
||||
xla::Backend* backend = client()->mutable_backend();
|
||||
|
||||
// Ensure all our streams are valid, borrowing new streams if necessary.
|
||||
bool need_new_device_context = !device_context_;
|
||||
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
|
||||
&need_new_device_context));
|
||||
|
||||
se::Stream* host_to_device_stream = stream_.get();
|
||||
se::Stream* device_to_host_stream = stream_.get();
|
||||
if (use_multiple_streams_) {
|
||||
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
|
||||
&host_to_device_stream_,
|
||||
&need_new_device_context));
|
||||
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
|
||||
&device_to_host_stream_,
|
||||
&need_new_device_context));
|
||||
host_to_device_stream = host_to_device_stream_.get();
|
||||
device_to_host_stream = device_to_host_stream_.get();
|
||||
}
|
||||
|
||||
if (!need_new_device_context) {
|
||||
return device_context_;
|
||||
}
|
||||
|
||||
// At this point we know we need a new device context.
|
||||
// Call GetAllocator for the side-effect of ensuring the allocator is created.
|
||||
GetAllocatorLocked({});
|
||||
if (device_context_) {
|
||||
device_context_->Unref();
|
||||
}
|
||||
device_context_ = new XlaDeviceContext(
|
||||
stream_.get(), host_to_device_stream, device_to_host_stream, client(),
|
||||
transfer_as_literal_, shape_representation_fn_);
|
||||
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
|
||||
<< device_context_;
|
||||
|
||||
// Create and set a new GpuDeviceInfo, if necessary.
|
||||
//
|
||||
// TODO(b/78232898): This isn't thread-safe; there is a race between the call
|
||||
// to set_tensorflow_gpu_device_info() with ops that call the getter
|
||||
// tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
|
||||
// to those methods; see the bug for details. Our only saving grace at the
|
||||
// moment is that this race doesn't seem to occur in practice.
|
||||
if (use_gpu_device_info_) {
|
||||
auto gpu_device_info = MakeUnique<GpuDeviceInfo>();
|
||||
gpu_device_info->stream = stream_.get();
|
||||
gpu_device_info->default_context = device_context_;
|
||||
set_tensorflow_gpu_device_info(gpu_device_info.get());
|
||||
gpu_device_info_ = std::move(gpu_device_info);
|
||||
VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
|
||||
<< gpu_device_info_.get();
|
||||
}
|
||||
|
||||
return device_context_;
|
||||
}
|
||||
|
||||
Status XlaDevice::UseGpuDeviceInfo() {
|
||||
mutex_lock lock(mu_);
|
||||
use_gpu_device_info_ = true;
|
||||
return GetDeviceContextLocked().status();
|
||||
}
|
||||
|
||||
Status XlaDevice::FillContextMap(const Graph* graph,
|
||||
DeviceContextMap* device_context_map) {
|
||||
VLOG(1) << "XlaDevice::FillContextMap";
|
||||
device_context_map->resize(graph->num_node_ids());
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
|
||||
GetDeviceToHostStream());
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
|
||||
GetHostToDeviceStream());
|
||||
mutex_lock lock(mu_);
|
||||
TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
|
||||
GetDeviceContextLocked());
|
||||
|
||||
// Call GetAllocator for the side-effect of ensuring the allocator is created.
|
||||
GetAllocator({});
|
||||
auto ctx = new XlaDeviceContext(
|
||||
stream, host_to_device_stream, device_to_host_stream, client(),
|
||||
transfer_as_literal_, shape_representation_fn_);
|
||||
device_context_map->resize(graph->num_node_ids());
|
||||
for (Node* n : graph->nodes()) {
|
||||
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
|
||||
ctx->Ref();
|
||||
(*device_context_map)[n->id()] = ctx;
|
||||
device_context->Ref();
|
||||
(*device_context_map)[n->id()] = device_context;
|
||||
}
|
||||
ctx->Unref();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
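FillContextMap above hands the same XlaDeviceContext to every node and takes one reference per slot, rather than building a fresh context per call. A tiny sketch of that manual reference-counting discipline (the RefCounted type below is a stand-in for TensorFlow's ref-counted base class):

    #include <vector>

    class RefCounted {
     public:
      void Ref() { ++refs_; }
      void Unref() { if (--refs_ == 0) delete this; }
     protected:
      virtual ~RefCounted() = default;
     private:
      int refs_ = 1;  // The creator starts with one reference.
    };

    class FakeDeviceContext : public RefCounted {};

    // Every slot shares `ctx` and holds its own reference, so the map stays
    // valid even after the creator drops its original reference.
    void FillMap(FakeDeviceContext* ctx, std::vector<FakeDeviceContext*>* map, int n) {
      map->resize(n);
      for (int i = 0; i < n; ++i) {
        ctx->Ref();
        (*map)[i] = ctx;
      }
    }

In the diff, the XlaDevice itself keeps the long-lived device_context_ reference, which is why the old trailing ctx->Unref() disappears.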
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
|
||||
VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
|
||||
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
|
||||
<< op_kernel->type_string();
|
||||
// When Xprof profiling is off (which is the default), constructing the
|
||||
// activity is simple enough that its overhead is negligible.
|
||||
@ -336,7 +364,7 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
|
||||
|
||||
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
|
||||
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
|
||||
<< op_kernel->type_string();
|
||||
tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
|
||||
op_kernel->IsExpensive());
|
||||
@ -358,17 +386,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
if (alloc_attrs.on_host()) {
|
||||
*tensor = parsed;
|
||||
} else {
|
||||
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
|
||||
mutex_lock lock(mu_);
|
||||
TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
|
||||
GetDeviceContextLocked());
|
||||
Allocator* allocator = GetAllocatorLocked(alloc_attrs);
|
||||
Tensor copy(allocator, parsed.dtype(), parsed.shape());
|
||||
Notification n;
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
|
||||
GetDeviceToHostStream());
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
|
||||
GetHostToDeviceStream());
|
||||
XlaTransferManager manager(stream, host_to_device_stream,
|
||||
device_to_host_stream, client(),
|
||||
transfer_as_literal_, shape_representation_fn_);
|
||||
manager.CopyCPUTensorToDevice(&parsed, this, ©,
|
||||
device_context->CopyCPUTensorToDevice(&parsed, this, ©,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device_context.h"
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -40,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -117,62 +119,85 @@ class XlaDevice : public LocalDevice {
|
||||
const PaddedShapeFn& padded_shape_fn);
|
||||
~XlaDevice() override;
|
||||
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override;
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
|
||||
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) override;
|
||||
Status Sync() override { return Status::OK(); }
|
||||
|
||||
Status FillContextMap(const Graph* graph,
|
||||
DeviceContextMap* device_context_map) override;
|
||||
DeviceContextMap* device_context_map) override
|
||||
LOCKS_EXCLUDED(mu_);
|
||||
|
||||
Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) override;
|
||||
Tensor* tensor) override LOCKS_EXCLUDED(mu_);
|
||||
|
||||
xla::LocalClient* client() const;
|
||||
const Metadata& metadata() { return xla_metadata_; }
|
||||
xla::StatusOr<se::Stream*> GetStream();
|
||||
xla::StatusOr<se::Stream*> GetHostToDeviceStream();
|
||||
xla::StatusOr<se::Stream*> GetDeviceToHostStream();
|
||||
|
||||
// If not already set, create and set GpuDeviceInfo.
|
||||
// Not thread-safe
|
||||
Status CreateAndSetGpuDeviceInfo();
|
||||
// Ensures the DeviceContext associated with this XlaDevice is created and
|
||||
// valid (i.e. all streams are ok). If any state is not valid, a new
|
||||
// DeviceContext will be created.
|
||||
//
|
||||
// TODO(b/111859745): The Eager context needs to call this method to recover
|
||||
// from failures.
|
||||
Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
|
||||
// information for GPU and TPU devices.
|
||||
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
xla::LocalClient* client() const;
|
||||
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
|
||||
xla::StreamPool::Ptr* stream,
|
||||
bool* stream_was_changed)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
mutex mu_;
|
||||
// The metadata of this XlaDevice.
|
||||
const Metadata xla_metadata_;
|
||||
// Which hardware device in the client's platform this XlaDevice controls.
|
||||
const int device_ordinal_;
|
||||
// The name of the device that is used to compile Ops for this XlaDevice.
|
||||
DeviceType jit_device_name_;
|
||||
const DeviceType jit_device_name_;
|
||||
// The platform for this device.
|
||||
se::Platform* const platform_; // Not owned.
|
||||
// Memory allocator associated with this device.
|
||||
Allocator* xla_allocator_; // Not owned.
|
||||
se::Platform* platform_; // Not owned.
|
||||
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
// Stream associated with this device. Operations enqueued on this
|
||||
// stream are executed on the device. Operations include data
|
||||
// copying back and forth between CPU and the device, and
|
||||
// computations enqueued by XLA.
|
||||
xla::StreamPool::Ptr stream_;
|
||||
// If true, only stream_ is valid and all computation and transfers use
|
||||
// stream_. If false, computation is performed by stream_ and transfers are
|
||||
xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
|
||||
// If false, only stream_ is valid and all computation and transfers use
|
||||
// stream_. If true, computation is performed by stream_ and transfers are
|
||||
// performed by host_to_device/device_to_host_stream.
|
||||
bool use_multiple_streams_;
|
||||
const bool use_multiple_streams_;
|
||||
// If use_multiple_streams_, host to device transfers are performed using this
|
||||
// stream.
|
||||
xla::StreamPool::Ptr host_to_device_stream_;
|
||||
xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
|
||||
// If use_multiple_streams_, device to host transfers are performed using this
|
||||
// stream.
|
||||
xla::StreamPool::Ptr device_to_host_stream_;
|
||||
xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
|
||||
// Must we use XLA's transfer manager for correct host<->device transfers? if
|
||||
// false, we can use ThenMemcpy() instead.
|
||||
bool transfer_as_literal_;
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
const bool transfer_as_literal_;
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// If set, holds default device context (that we must Unref)
|
||||
// and its stream.
|
||||
std::unique_ptr<GpuDeviceInfo> gpu_device_info_;
|
||||
// The device context accessed by all users of the XlaDevice, set by calls to
|
||||
// EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
|
||||
// also filled in to that struct. XlaDeviceContext is a ref-counted object.
|
||||
XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
// Holds extra information for GPU and TPU devices, e.g. the device context.
|
||||
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
|
||||
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Builds OpKernel registrations on 'device' for the JIT operators
|
||||
|
@ -101,34 +101,27 @@ Status XlaTransferManager::TransferLiteralToDevice(
|
||||
// Unref the host tensor, and capture the literal shared_ptr too so it goes
|
||||
// out of scope when the lambda completes.
|
||||
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaTransferManager::TransferLiteralFromDevice(
|
||||
Tensor* host_tensor, const Tensor& device_tensor,
|
||||
const StatusCallback& done) const {
|
||||
xla::MutableBorrowingLiteral literal;
|
||||
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
|
||||
|
||||
const xla::ShapedBuffer& shaped_buffer =
|
||||
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
|
||||
|
||||
TensorReference ref(device_tensor);
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
device_to_host_stream_, shaped_buffer,
|
||||
[=, &shaped_buffer](
|
||||
xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
|
||||
device_to_host_stream_, shaped_buffer, literal,
|
||||
[=, &shaped_buffer, &literal](xla::Status status) {
|
||||
ref.Unref();
|
||||
done([&]() -> Status {
|
||||
TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or));
|
||||
VLOG(1) << "Transfer from device as literal: " << literal->ToString()
|
||||
VLOG(1) << "Transfer from device as literal: " << literal.ToString()
|
||||
<< " " << shaped_buffer.ToString();
|
||||
Tensor tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
|
||||
// Reshape the tensor back to its declared shape.
|
||||
Status status;
|
||||
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
|
||||
status = errors::Internal(
|
||||
"Tensor::CopyFrom failed when copying from XLA device to CPU");
|
||||
}
|
||||
return status;
|
||||
}());
|
||||
});
|
||||
|
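In the rewritten TransferLiteralFromDevice above, the lambda passed to the transfer manager captures both the TensorReference and the MutableBorrowingLiteral, so the buffers stay alive until the asynchronous device-to-host copy finishes and the callback can report status. A minimal illustration of that capture-to-keep-alive pattern, with std::shared_ptr playing the role of the TensorReference:

    #include <functional>
    #include <memory>
    #include <string>
    #include <vector>

    using StatusCallback = std::function<void(const std::string&)>;

    // The returned callback owns a reference to `buffer`, so the data cannot be
    // freed before the transfer engine eventually invokes it.
    std::function<void()> MakeCompletionCallback(
        std::shared_ptr<std::vector<float>> buffer, StatusCallback done) {
      return [buffer, done]() {
        done("transferred " + std::to_string(buffer->size()) + " floats");
      };
    }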
@ -23,7 +23,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/constant_op.h"
|
||||
#include "tensorflow/core/kernels/control_flow_ops.h"
|
||||
#include "tensorflow/core/kernels/data/generator_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/fifo_queue.h"
|
||||
#include "tensorflow/core/kernels/function_ops.h"
|
||||
#include "tensorflow/core/kernels/identity_n_op.h"
|
||||
#include "tensorflow/core/kernels/identity_op.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
@ -166,7 +170,69 @@ class XlaAssignVariableOp : public AsyncOpKernel {
|
||||
QueueIsClosedOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
|
||||
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \
|
||||
TYPES), \
|
||||
ArgOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<ResourceHandle>("T"), \
|
||||
ArgOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint("T", TYPES) \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<ResourceHandle>("T") \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \
|
||||
GeneratorDatasetOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("buffer_size") \
|
||||
.HostMemory("input_dataset") \
|
||||
.HostMemory("handle"), \
|
||||
PrefetchDatasetOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \
|
||||
IteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \
|
||||
MakeIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
|
||||
AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
|
||||
IteratorGetNextOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("string_handle"), \
|
||||
IteratorToStringHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("string_handle"), \
|
||||
IteratorFromStringHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<string>("T"), \
|
||||
ArgOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<string>("T") \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp);
|
||||
|
||||
// TODO(phawkins): currently we do not register the QueueEnqueueMany,
|
||||
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
|
||||
|
@ -59,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
}
|
||||
|
||||
// TODO(b/78468222): Uncomment after fixing this bug
|
||||
// status = device->CreateAndSetGpuDeviceInfo();
|
||||
// status = device->UseGpuDeviceInfo();
|
||||
// if (!status.ok()) {
|
||||
// errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
|
||||
// " device");
|
||||
|
@ -691,11 +691,7 @@ tf_xla_py_test(
|
||||
size = "small",
|
||||
srcs = ["random_ops_test.py"],
|
||||
disabled_backends = [
|
||||
# TODO(b/110300529): RngNormal doesn't return values with the expected variance
|
||||
"cpu",
|
||||
"cpu_ondemand",
|
||||
# TODO(b/31361304): enable RNG ops on GPU when parallelized.
|
||||
"gpu",
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
|
@ -52,6 +52,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
for dtype in self.float_types:
|
||||
# TODO: test fails for float16 due to excessive precision requirements.
|
||||
if dtype == np.float16:
|
||||
continue
|
||||
with self.test_session(), self.test_scope():
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
@ -91,6 +94,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testTensorLearningRate(self):
|
||||
for dtype in self.float_types:
|
||||
# TODO: test fails for float16 due to excessive precision requirements.
|
||||
if dtype == np.float16:
|
||||
continue
|
||||
with self.test_session(), self.test_scope():
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
@ -130,6 +136,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testSharing(self):
|
||||
for dtype in self.float_types:
|
||||
# TODO: test fails for float16 due to excessive precision requirements.
|
||||
if dtype == np.float16:
|
||||
continue
|
||||
with self.test_session(), self.test_scope():
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional
|
||||
from tensorflow.python.layers import pooling
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -122,6 +123,14 @@ class EagerTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
self.assertAllEqual(2, array_ops.identity(2))
|
||||
|
||||
def testRandomOps(self):
|
||||
with self.test_scope():
|
||||
tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
|
||||
row0 = tensor[0].numpy()
|
||||
row1 = tensor[1].numpy()
|
||||
# It should be very unlikely to rng to generate two equal rows.
|
||||
self.assertFalse((row0 == row1).all())
|
||||
|
||||
def testIdentityOnVariable(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(True)
|
||||
|
@ -57,7 +57,8 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
def testRandomUniformIsNotConstant(self):
|
||||
|
||||
def rng(dtype):
|
||||
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max)
|
||||
|
||||
for dtype in self._random_types():
|
||||
self._testRngIsNotConstant(rng, dtype)
|
||||
@ -73,6 +74,11 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def testRandomUniformIsInRange(self):
|
||||
for dtype in self._random_types():
|
||||
# TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is
|
||||
# fixed.
|
||||
if (self.device in ["XLA_GPU", "XLA_CPU"
|
||||
]) and (dtype in [dtypes.bfloat16, dtypes.half]):
|
||||
continue
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.random_uniform(
|
||||
@ -95,7 +101,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
for dtype in [dtypes.float32]:
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42)
|
||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||
y = sess.run(x)
|
||||
|
||||
def normal_cdf(x):
|
||||
@ -124,20 +130,23 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
# Department of Scientific Computing website. Florida State University.
|
||||
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
|
||||
actual_mean = np.mean(y)
|
||||
self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
|
||||
self.assertAllClose(actual_mean, expected_mean, atol=2e-3)
|
||||
|
||||
expected_median = mu + probit(
|
||||
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
|
||||
actual_median = np.median(y)
|
||||
self.assertAllClose(actual_median, expected_median, atol=8e-4)
|
||||
self.assertAllClose(actual_median, expected_median, atol=1e-2)
|
||||
|
||||
expected_variance = sigma**2 * (1 + (
|
||||
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
||||
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
|
||||
actual_variance = np.var(y)
|
||||
self.assertAllClose(actual_variance, expected_variance, rtol=3e-4)
|
||||
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
|
||||
|
||||
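For reference, the expected_mean, expected_median and expected_variance expressions checked above (now with slightly looser tolerances) are the standard truncated-normal moments. With alpha = (a - mu)/sigma, beta = (b - mu)/sigma and Z = Phi(beta) - Phi(alpha), where phi is the standard normal pdf and Phi its cdf:

    E[X] = \mu + \frac{\varphi(\alpha) - \varphi(\beta)}{Z}\,\sigma
    \mathrm{median}[X] = \mu + \Phi^{-1}\!\left(\frac{\Phi(\alpha) + \Phi(\beta)}{2}\right)\sigma
    \mathrm{Var}[X] = \sigma^{2}\left[1 + \frac{\alpha\varphi(\alpha) - \beta\varphi(\beta)}{Z} - \left(\frac{\varphi(\alpha) - \varphi(\beta)}{Z}\right)^{2}\right]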
def testShuffle1d(self):
|
||||
# TODO(b/26783907): this test requires the CPU backend to implement sort.
|
||||
if self.device in ["XLA_CPU"]:
|
||||
return
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
x = math_ops.range(1 << 16)
|
||||
|
@ -361,6 +361,12 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[-0.05, 6.05, 5]], dtype=dtype),
|
||||
expected=np.array([[0, 6, 5]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softmax,
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428],
|
||||
dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softmax,
|
||||
np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
|
||||
@ -369,6 +375,14 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
|
||||
dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softmax,
|
||||
np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype),
|
||||
expected=np.array(
|
||||
[[[0.5, 0.5], [0.5, 0.5]],
|
||||
[[0.26894142, 0.73105858], [0.26894142, 0.73105858]]],
|
||||
dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softsign,
|
||||
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
|
||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_control_flow_ops
|
||||
@ -47,6 +49,34 @@ class XlaDeviceTest(xla_test.XLATestCase):
|
||||
result = sess.run(z, {x: inputs})
|
||||
self.assertAllCloseAccordingToType(result, inputs + inputs)
|
||||
|
||||
def testCopiesOfUnsupportedTypesFailGracefully(self):
|
||||
"""Tests that copies of unsupported types don't crash."""
|
||||
test_types = set([
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32,
|
||||
np.int64, np.float16, np.float32, np.float16,
|
||||
dtypes.bfloat16.as_numpy_dtype
|
||||
])
|
||||
shape = (10, 10)
|
||||
for unsupported_dtype in test_types - self.all_types:
|
||||
with self.test_session() as sess:
|
||||
with ops.device("CPU"):
|
||||
x = array_ops.placeholder(unsupported_dtype, shape)
|
||||
with self.test_scope():
|
||||
y, = array_ops.identity_n([x])
|
||||
with ops.device("CPU"):
|
||||
z = array_ops.identity(y)
|
||||
|
||||
inputs = np.random.randint(-100, 100, shape)
|
||||
inputs = inputs.astype(unsupported_dtype)
|
||||
# Execution should either succeed or raise an InvalidArgumentError,
|
||||
# but not crash. Even "unsupported types" may succeed here since some
|
||||
# backends (e.g., the CPU backend) are happy to handle buffers of
|
||||
# unsupported types, even if they cannot compute with them.
|
||||
try:
|
||||
sess.run(z, {x: inputs})
|
||||
except errors.InvalidArgumentError:
|
||||
pass
|
||||
|
||||
def testControlTrigger(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
|
@ -95,6 +95,10 @@ cc_library(
|
||||
name = "cpu_function_runtime",
|
||||
srcs = ["cpu_function_runtime.cc"],
|
||||
hdrs = ["cpu_function_runtime.h"],
|
||||
visibility = [
|
||||
"//tensorflow/compiler/aot:__pkg__",
|
||||
"//tensorflow/compiler/xla/service/cpu:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
# Keep dependencies to a minimum here; this library is used in every AOT
|
||||
# binary produced by tfcompile.
|
||||
@ -144,6 +148,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_executable",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -55,19 +55,26 @@ size_t align_to(size_t n, size_t align) {
|
||||
} // namespace
|
||||
|
||||
namespace cpu_function_runtime {
|
||||
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) {
|
||||
size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
|
||||
bool allocate_entry_params) {
|
||||
size_t total = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if (sizes[i] > 0) {
|
||||
total += align_to(sizes[i], kAlign);
|
||||
bool should_allocate =
|
||||
buffer_infos[i].is_temp_buffer() ||
|
||||
(buffer_infos[i].is_entry_parameter() && allocate_entry_params);
|
||||
|
||||
if (should_allocate) {
|
||||
total += align_to(buffer_infos[i].size(), kAlign);
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
|
||||
void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
|
||||
bool allocate_entry_params, void** bufs,
|
||||
bool annotate_initialized) {
|
||||
const size_t total = AlignedBufferBytes(sizes, n);
|
||||
const size_t total =
|
||||
AlignedBufferBytes(buffer_infos, n, allocate_entry_params);
|
||||
void* contiguous = nullptr;
|
||||
if (total > 0) {
|
||||
contiguous = aligned_malloc(total, kAlign);
|
||||
@ -79,13 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
|
||||
}
|
||||
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
if (sizes[i] < 0) {
// bufs[i] is either a constant, an entry parameter or a thread local
// allocation.
bufs[i] = nullptr;
} else {
bool should_allocate =
buffer_infos[i].is_temp_buffer() ||
(buffer_infos[i].is_entry_parameter() && allocate_entry_params);
if (should_allocate) {
bufs[i] = reinterpret_cast<void*>(pos);
pos += align_to(sizes[i], kAlign);
pos += align_to(buffer_infos[i].size(), kAlign);
} else {
bufs[i] = nullptr;
}
}
return contiguous;

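MallocContiguousBuffers above carves one contiguous allocation into per-buffer slices, each rounded up to the 64-byte kAlign boundary, and leaves nullptr in slots the runtime is not responsible for (constants, on-stack buffers, and entry parameters when allocate_entry_params is false). A small self-contained sketch of the same parceling logic; it uses plain malloc, so the base alignment here is only whatever malloc guarantees:

    #include <cstddef>
    #include <cstdlib>
    #include <vector>

    constexpr size_t kAlign = 64;

    // Rounds n up to the next multiple of align.
    size_t AlignTo(size_t n, size_t align) { return (n + align - 1) / align * align; }

    // Allocates one block and parcels it out; a size of 0 marks a slot that is
    // not allocated here and therefore stays null.
    void* ParcelBuffers(const std::vector<size_t>& sizes, std::vector<void*>* bufs) {
      size_t total = 0;
      for (size_t s : sizes) total += AlignTo(s, kAlign);
      char* base = total > 0 ? static_cast<char*>(std::malloc(total)) : nullptr;
      size_t pos = 0;
      bufs->clear();
      for (size_t s : sizes) {
        if (s == 0) {
          bufs->push_back(nullptr);
        } else {
          bufs->push_back(base + pos);
          pos += AlignTo(s, kAlign);
        }
      }
      return base;  // Free this single pointer to release every buffer at once.
    }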
@ -18,29 +18,142 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace tensorflow {
|
||||
namespace cpu_function_runtime {
|
||||
// Stores information about one buffer used by an XLA:CPU compiled function.
|
||||
// These buffers are used for holding inputs to the computation, outputs from
|
||||
// the computation and as temporary scratch space.
|
||||
class BufferInfo {
|
||||
public:
|
||||
// Creates a BufferInfo from a serialized encoding generated by `Encode`.
|
||||
explicit BufferInfo(std::pair<uint64, uint64> encoding)
|
||||
: entry_param_number_(encoding.second) {
|
||||
Kind kind;
|
||||
uint64 size;
|
||||
Unpack(encoding.first, &kind, &size);
|
||||
kind_ = kind;
|
||||
size_ = size;
|
||||
}
|
||||
|
||||
// Returns true if this buffer stores a constant. These never need to be
|
||||
// allocated by the runtime.
|
||||
bool is_constant() const { return kind() == Kind::kConstant; }
|
||||
|
||||
// Returns true if this buffer stores an entry parameter. These may or may
|
||||
// not need to be allocated by the runtime, depending on
|
||||
// XlaCompiledCpuFunction::AllocMode.
|
||||
bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; }
|
||||
|
||||
// Returns the entry parameter number of this buffer.
|
||||
uint64 entry_parameter_number() const {
|
||||
assert(is_entry_parameter());
|
||||
return entry_param_number_;
|
||||
}
|
||||
|
||||
// Returns true if this buffer is temporary scratch space required by the XLA
|
||||
// computations. These are always allocated by the runtime.
|
||||
bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; }
|
||||
|
||||
// Returns true if this buffer is allocated on the C stack or into registers.
|
||||
// These buffers are never allocated by the runtime.
|
||||
bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; }
|
||||
|
||||
// Returns the size for this buffer.
|
||||
uint64 size() const { return size_; }
|
||||
|
||||
// Encodes this BufferInfo into two 64 bit integers that can be used to
|
||||
// reconstruct the BufferInfo later using the constructor. We need this
|
||||
// because we use BufferInfo in places where using protocol buffers would
|
||||
// negatively impact binary size.
|
||||
std::pair<uint64, uint64> Encode() const {
|
||||
static_assert(sizeof(*this) == 16, "");
|
||||
uint64 upper = Pack(kind(), size_);
|
||||
uint64 lower = entry_param_number_;
|
||||
return {upper, lower};
|
||||
}
|
||||
|
||||
bool operator==(const BufferInfo& buffer_info) const {
|
||||
if (kind() != buffer_info.kind() || size() != buffer_info.size()) {
|
||||
return false;
|
||||
}
|
||||
return !is_entry_parameter() ||
|
||||
entry_parameter_number() == buffer_info.entry_parameter_number();
|
||||
}
|
||||
|
||||
// Factory methods:
|
||||
|
||||
static BufferInfo MakeTempBuffer(uint64 size) {
|
||||
return BufferInfo(Kind::kTempBuffer, /*size=*/size,
|
||||
/*entry_param_number=*/-1);
|
||||
}
|
||||
static BufferInfo MakeConstant(uint64 size) {
|
||||
return BufferInfo(Kind::kConstant, /*size=*/size,
|
||||
/*entry_param_number=*/-1);
|
||||
}
|
||||
static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) {
|
||||
return BufferInfo(Kind::kEntryParameter, /*size=*/size,
|
||||
/*entry_param_number=*/param_number);
|
||||
}
|
||||
static BufferInfo MakeOnStackBuffer(uint64 size) {
|
||||
return BufferInfo(Kind::kOnStackBuffer, /*size=*/size,
|
||||
/*entry_param_number=*/-1);
|
||||
}
|
||||
|
||||
private:
|
||||
BufferInfo() = default;
|
||||
|
||||
enum class Kind : unsigned {
|
||||
kConstant,
|
||||
kTempBuffer,
|
||||
kEntryParameter,
|
||||
kOnStackBuffer
|
||||
};
|
||||
|
||||
Kind kind() const { return static_cast<Kind>(kind_); }

explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number)
: kind_(kind), size_(size), entry_param_number_(entry_param_number) {}

static uint64 Pack(Kind kind, uint64 size) {
return (static_cast<uint64>(size) << 2) | static_cast<uint64>(kind);
}

static void Unpack(uint64 packed, Kind* kind, uint64* size) {
*size = packed >> 2;
*kind = static_cast<Kind>((packed << 62) >> 62);
}

Kind kind_ : 2;
uint64 size_ : 62;
int64 entry_param_number_;
};

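The Encode/Pack/Unpack trio above squeezes a BufferInfo into two 64-bit words: the low 2 bits of the first word carry the kind tag and the upper 62 bits carry the size, which keeps the generated AOT headers free of any protobuf dependency. A compilable sketch of just that packing, with a quick round-trip check:

    #include <cassert>
    #include <cstdint>

    enum class Kind : unsigned { kConstant, kTempBuffer, kEntryParameter, kOnStackBuffer };

    // Low 2 bits: kind tag.  Upper 62 bits: buffer size.
    uint64_t Pack(Kind kind, uint64_t size) {
      return (size << 2) | static_cast<uint64_t>(kind);
    }

    void Unpack(uint64_t packed, Kind* kind, uint64_t* size) {
      *size = packed >> 2;
      *kind = static_cast<Kind>(packed & 3);  // Same result as (packed << 62) >> 62.
    }

    int main() {
      Kind kind;
      uint64_t size = 0;
      Unpack(Pack(Kind::kTempBuffer, 1234), &kind, &size);
      assert(kind == Kind::kTempBuffer && size == 1234);
      return 0;
    }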
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
|
||||
constexpr size_t kAlign = 64;
|
||||
|
||||
// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1
|
||||
// values. There are `n` entries in `sizes`. Each buffer is aligned to
|
||||
// kAlign byte boundaries.
|
||||
size_t AlignedBufferBytes(const intptr_t* sizes, size_t n);
|
||||
// AlignedBufferBytes returns the sum of the size of each buffer in
|
||||
// `buffer_infos`, skipping constants, on-stack buffers and, if
|
||||
// allocate_entry_params is false, entry parameters. There are `n` entries in
|
||||
// `buffer_infos`. Each buffer is aligned to kAlign byte boundaries.
|
||||
size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
|
||||
bool allocate_entry_params);
|
||||
|
||||
// MallocContiguousBuffers allocates buffers for use by the entry point
|
||||
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
|
||||
// where -1 causes the buffer pointer to be nullptr. There are `n` entries in
|
||||
// `sizes`. If `annotate_initialized` is set, the allocated memory will be
|
||||
// annotated as having been initialized - this is useful when allocating
|
||||
// temporary buffers.
|
||||
// generated by tfcompile. There are `n` entries in `buffer_infos`. If
|
||||
// `annotate_initialized` is set, the allocated memory will be annotated as
|
||||
// having been initialized - this is useful when allocating temporary buffers.
|
||||
// If allocate_entry_params is true then allocates temp buffers and entry
|
||||
// parameters, otherwise allocated only temp buffers. Slots in `bufs`
|
||||
// corresponding to unallocated buffers are set to nullptr.
|
||||
//
|
||||
// A single contiguous block of memory is allocated, and portions of it are
|
||||
// parceled out into `bufs`, which must have space for `n` entries. Returns
|
||||
// the head of the allocated contiguous block, which should be passed to
|
||||
// FreeContiguous when the buffers are no longer in use.
|
||||
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
|
||||
void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
|
||||
bool allocate_entry_params, void** bufs,
|
||||
bool annotate_initialized);
|
||||
|
||||
// FreeContiguous frees the contiguous block of memory allocated by
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using cpu_function_runtime::BufferInfo;
|
||||
|
||||
TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
|
||||
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the
|
||||
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
|
||||
@ -30,20 +32,51 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
|
||||
EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment);
|
||||
}
|
||||
|
||||
std::vector<BufferInfo> SizesToBufferInfos(const intptr_t* sizes, size_t n) {
|
||||
std::vector<BufferInfo> buffer_infos;
|
||||
std::transform(sizes, sizes + n, std::back_inserter(buffer_infos),
|
||||
[&](intptr_t size) {
|
||||
if (size == -1) {
|
||||
// Use a dummy on-stack buffer allocation to indicat the
|
||||
// the current slot does not need an allocation.
|
||||
int64 on_stack_buffer_size = 4;
|
||||
return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size);
|
||||
}
|
||||
return BufferInfo::MakeTempBuffer(size);
|
||||
});
|
||||
return buffer_infos;
|
||||
}
|
||||
|
||||
// Simple wrappers to make writing tests more ergonomic.
|
||||
|
||||
size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) {
|
||||
std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
|
||||
return AlignedBufferBytes(buffer_infos.data(), n,
|
||||
/*allocate_entry_params=*/false);
|
||||
}
|
||||
|
||||
void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n,
|
||||
void** bufs, bool annotate_initialized) {
|
||||
std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
|
||||
return MallocContiguousBuffers(buffer_infos.data(), n,
|
||||
/*allocate_entry_params=*/false, bufs,
|
||||
annotate_initialized);
|
||||
}
|
||||
|
||||
TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) {
|
||||
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0);
|
||||
EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0);
|
||||
|
||||
static constexpr intptr_t sizesA[1] = {-1};
|
||||
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0);
|
||||
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0);
|
||||
|
||||
static constexpr intptr_t sizesB[1] = {3};
|
||||
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64);
|
||||
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64);
|
||||
|
||||
static constexpr intptr_t sizesC[1] = {32};
|
||||
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64);
|
||||
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64);
|
||||
|
||||
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
|
||||
EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320);
|
||||
EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320);
|
||||
}
|
||||
|
||||
void* add_ptr(void* base, uintptr_t delta) {
|
||||
@ -56,15 +89,14 @@ void* add_ptr(void* base, uintptr_t delta) {
|
||||
// free. We also check the contiguous property.
|
||||
TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
|
||||
// Test empty sizes.
|
||||
void* base =
|
||||
cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false);
|
||||
void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false);
|
||||
EXPECT_EQ(base, nullptr);
|
||||
cpu_function_runtime::FreeContiguous(base);
|
||||
|
||||
// Test non-empty sizes with 0 sum.
|
||||
static constexpr intptr_t sizesA[1] = {-1};
|
||||
void* bufA[1];
|
||||
base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false);
|
||||
base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false);
|
||||
EXPECT_EQ(base, nullptr);
|
||||
EXPECT_EQ(bufA[0], nullptr);
|
||||
cpu_function_runtime::FreeContiguous(base);
|
||||
@ -72,7 +104,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
|
||||
// Test non-empty sizes with non-0 sum.
|
||||
static constexpr intptr_t sizesB[1] = {3};
|
||||
void* bufB[1];
|
||||
base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false);
|
||||
base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false);
|
||||
EXPECT_NE(base, nullptr);
|
||||
EXPECT_EQ(bufB[0], add_ptr(base, 0));
|
||||
char* bufB0_bytes = static_cast<char*>(bufB[0]);
|
||||
@ -84,7 +116,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
|
||||
// Test non-empty sizes with non-0 sum, and annotate_initialized.
|
||||
static constexpr intptr_t sizesC[1] = {3};
|
||||
void* bufC[1];
|
||||
base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true);
|
||||
base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true);
|
||||
EXPECT_NE(base, nullptr);
|
||||
EXPECT_EQ(bufC[0], add_ptr(base, 0));
|
||||
char* bufC0_bytes = static_cast<char*>(bufC[0]);
|
||||
@ -96,7 +128,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
|
||||
// Test mixed sizes.
|
||||
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
|
||||
void* bufD[7];
|
||||
base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false);
|
||||
base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false);
|
||||
EXPECT_NE(base, nullptr);
|
||||
EXPECT_EQ(bufD[0], add_ptr(base, 0));
|
||||
EXPECT_EQ(bufD[1], nullptr);
|
||||
@ -117,5 +149,23 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
|
||||
cpu_function_runtime::FreeContiguous(base);
|
||||
}
|
||||
|
||||
void CheckRoundTripIsOk(const BufferInfo& buffer_info) {
|
||||
BufferInfo round_trip(buffer_info.Encode());
|
||||
ASSERT_EQ(round_trip, buffer_info);
|
||||
}
|
||||
|
||||
TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) {
|
||||
CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0));
|
||||
CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4));
|
||||
CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0));
|
||||
CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4));
|
||||
CheckRoundTripIsOk(BufferInfo::MakeConstant(0));
|
||||
CheckRoundTripIsOk(BufferInfo::MakeConstant(4));
|
||||
CheckRoundTripIsOk(
|
||||
BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4));
|
||||
CheckRoundTripIsOk(
|
||||
BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -6,6 +6,10 @@ package(
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||
load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
"if_mkl",
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "xla_ops",
|
||||
@ -129,6 +133,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/xla/client/lib:constants",
|
||||
"//tensorflow/compiler/xla/client/lib:math",
|
||||
"//tensorflow/compiler/xla/client/lib:numeric",
|
||||
"//tensorflow/compiler/xla/client/lib:pooling",
|
||||
"//tensorflow/compiler/xla/client/lib:prng",
|
||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||
"//tensorflow/core:framework",
|
||||
@ -153,8 +158,14 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:sparse_to_dense_op",
|
||||
"//tensorflow/core/kernels:stack_ops",
|
||||
"//tensorflow/core/kernels:training_ops",
|
||||
] + if_mkl(
|
||||
[
|
||||
"//tensorflow/core/kernels:mkl_transpose_op",
|
||||
],
|
||||
[
|
||||
"//tensorflow/core/kernels:transpose_op",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp);
|
||||
REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -200,12 +200,23 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
bool resource_variable_seen = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
if (ctx->input_type(i) == DT_RESOURCE) {
resource_variable_seen = true;
} else {
OP_REQUIRES(
ctx, !resource_variable_seen,
errors::FailedPrecondition(
"Resource variables and regular inputs cannot be interleaved."));
}
}
|
||||
|
||||
xla::XlaOp outputs = xla::Conditional(
|
||||
ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
|
||||
xla::Tuple(b, inputs), *else_result.computation);
|
||||
// Sets non-variable outputs.
|
||||
for (int i = 0; i < output_types_.size(); ++i) {
|
||||
if (ctx->input_type(i) != DT_RESOURCE) {
|
||||
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
LOG(INFO) << "Setting output " << i;
|
||||
@ -219,7 +230,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
}
|
||||
ctx->SetOutput(i, output_handle);
|
||||
}
|
||||
}
|
||||
|
||||
// Updates the values of any resource variables modified by the conditional
|
||||
// bodies.
|
||||
@ -247,6 +257,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp);
|
||||
REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp);
|
||||
REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
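The new check above enforces an ordering invariant: all DT_RESOURCE inputs of If must come after every regular input. A minimal standalone sketch of the same predicate over a plain vector of input types, not taken from this commit; the helper name ResourcesTrailRegularInputs is hypothetical.

#include <vector>
#include "tensorflow/core/framework/types.pb.h"

namespace {
// Returns true if no regular input appears after a resource input.
bool ResourcesTrailRegularInputs(const std::vector<tensorflow::DataType>& types) {
  bool resource_seen = false;
  for (tensorflow::DataType t : types) {
    if (t == tensorflow::DT_RESOURCE) {
      resource_seen = true;
    } else if (resource_seen) {
      return false;  // A regular input after a resource input: interleaved.
    }
  }
  return true;
}
}  // namespace

// Sketch usage: {DT_FLOAT, DT_RESOURCE, DT_RESOURCE} passes,
// {DT_RESOURCE, DT_FLOAT} would trip the FailedPrecondition above.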
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/pooling.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
@ -71,59 +72,53 @@ class PoolingOp : public XlaOpKernel {
|
||||
|
||||
int num_dims() const { return num_spatial_dims_ + 2; }
|
||||
|
||||
// Method that builds an initial value to use in reductions.
|
||||
virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0;
|
||||
|
||||
// The reduction operation to apply to each window.
|
||||
virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0;
|
||||
|
||||
// A post-processing operation to apply on the outputs of the ReduceWindow.
|
||||
virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
|
||||
const xla::XlaOp& output, DataType dtype,
|
||||
const TensorShape& input_shape) = 0;
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
std::vector<int64> ksize = ksize_;
|
||||
std::vector<int64> stride = stride_;
|
||||
if (ctx->num_inputs() != 1) {
|
||||
protected:
|
||||
xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
|
||||
if (ctx->num_inputs() == 1) {
|
||||
return ksize_;
|
||||
}
|
||||
const TensorShape ksize_shape = ctx->InputShape(1);
|
||||
// Validate input sizes.
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
|
||||
errors::InvalidArgument("ksize must be a vector, not shape ",
|
||||
ksize_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
|
||||
errors::InvalidArgument("Sliding window ksize field must "
|
||||
if (!TensorShapeUtils::IsVector(ksize_shape)) {
|
||||
return errors::InvalidArgument("ksize must be a vector, not shape ",
|
||||
ksize_shape.DebugString());
|
||||
}
|
||||
if (ksize_shape.num_elements() != num_dims()) {
|
||||
return errors::InvalidArgument(
|
||||
"Sliding window ksize field must "
|
||||
"specify ",
|
||||
num_dims(), " dimensions"));
|
||||
ksize.clear();
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
|
||||
num_dims(), " dimensions");
|
||||
}
|
||||
std::vector<int64> ksize;
|
||||
auto status = ctx->ConstantInputAsIntVector(1, &ksize);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
return ksize;
|
||||
}
|
||||
|
||||
xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
|
||||
if (ctx->num_inputs() == 1) {
|
||||
return stride_;
|
||||
}
|
||||
const TensorShape stride_shape = ctx->InputShape(2);
|
||||
// Validate input sizes.
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
|
||||
errors::InvalidArgument("stride must be a vector, not shape ",
|
||||
stride_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
|
||||
errors::InvalidArgument("Sliding window stride field must "
|
||||
"specify ",
|
||||
num_dims(), " dimensions"));
|
||||
stride.clear();
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
|
||||
if (!TensorShapeUtils::IsVector(stride_shape)) {
|
||||
return errors::InvalidArgument("stride must be a vector, not shape ",
|
||||
stride_shape.DebugString());
|
||||
}
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
|
||||
errors::InvalidArgument("Input to ", type_string(),
|
||||
" operator must have ", num_dims(),
|
||||
" dimensions"));
|
||||
|
||||
xla::XlaBuilder* const b = ctx->builder();
|
||||
auto input =
|
||||
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
|
||||
auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
|
||||
stride, padding_);
|
||||
auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
|
||||
ctx->SetOutput(0,
|
||||
PostProcessOutput(ctx, pooled, input_type(0), input_shape));
|
||||
if (stride_shape.num_elements() != num_dims()) {
|
||||
return errors::InvalidArgument(
|
||||
"Sliding window stride field must "
|
||||
"specify ",
|
||||
num_dims(), " dimensions");
|
||||
}
|
||||
std::vector<int64> stride;
|
||||
auto status = ctx->ConstantInputAsIntVector(2, &stride);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
return stride;
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -136,24 +131,48 @@ class PoolingOp : public XlaOpKernel {
|
||||
xla::PrimitiveType xla_reduction_type_;
|
||||
};
|
||||
|
||||
// Converts the tensor data format to the one required by the XLA pooling
|
||||
// library.
|
||||
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
|
||||
int num_spatial_dims) {
|
||||
int num_dims = num_spatial_dims + 2;
|
||||
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
|
||||
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
|
||||
gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
|
||||
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
|
||||
spatial_dimensions[spatial_dim] =
|
||||
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
|
||||
}
|
||||
return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
|
||||
/*feature_dimension=*/feature_dimension,
|
||||
/*spatial_dimensions=*/spatial_dimensions);
|
||||
}
|
||||
|
||||
class MaxPoolOp : public PoolingOp {
|
||||
public:
|
||||
MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
|
||||
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
|
||||
/*reduction_type=*/ctx->input_type(0)) {}
|
||||
|
||||
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
|
||||
return xla::MinValue(b, xla_reduction_type_);
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto ksize_or_error = GetKernelSize(ctx);
|
||||
OP_REQUIRES_OK(ctx, ksize_or_error.status());
|
||||
std::vector<int64> ksize = ksize_or_error.ValueOrDie();
|
||||
|
||||
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
|
||||
return ctx->GetOrCreateMax(reduction_type_);
|
||||
}
|
||||
auto stride_or_error = GetStride(ctx);
|
||||
OP_REQUIRES_OK(ctx, stride_or_error.status());
|
||||
std::vector<int64> stride = stride_or_error.ValueOrDie();
|
||||
|
||||
xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
|
||||
const xla::XlaOp& output, DataType dtype,
|
||||
const TensorShape& input_shape) override {
|
||||
return output;
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
|
||||
errors::InvalidArgument("Input to ", type_string(),
|
||||
" operator must have ", num_dims(),
|
||||
" dimensions"));
|
||||
|
||||
auto pooling =
|
||||
xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
|
||||
XlaTensorFormat(data_format_, input_shape.dims() - 2));
|
||||
ctx->SetOutput(0, pooling);
|
||||
}
|
||||
};
|
||||
|
||||
@ -180,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp {
|
||||
};
|
||||
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
|
||||
|
||||
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
|
||||
// element of an image by the count of elements that contributed to that
|
||||
// element during pooling.
|
||||
// Divide each element of an image by the count of elements that contributed to
|
||||
// that element during pooling.
|
||||
static xla::XlaOp AvgPoolDivideByCount(
|
||||
XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
|
||||
const TensorShape& input_shape, xla::Padding padding,
|
||||
@ -241,20 +259,34 @@ class AvgPoolOp : public PoolingOp {
|
||||
/*reduction_type=*/
|
||||
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
|
||||
|
||||
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
|
||||
return xla::Zero(b, xla_reduction_type_);
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto ksize_or_error = GetKernelSize(ctx);
|
||||
OP_REQUIRES_OK(ctx, ksize_or_error.status());
|
||||
std::vector<int64> ksize = ksize_or_error.ValueOrDie();
|
||||
|
||||
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
|
||||
return ctx->GetOrCreateAdd(reduction_type_);
|
||||
}
|
||||
auto stride_or_error = GetStride(ctx);
|
||||
OP_REQUIRES_OK(ctx, stride_or_error.status());
|
||||
std::vector<int64> stride = stride_or_error.ValueOrDie();
|
||||
|
||||
xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
|
||||
const xla::XlaOp& output, DataType dtype,
|
||||
const TensorShape& input_shape) override {
|
||||
return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
|
||||
ksize_, stride_, num_spatial_dims_,
|
||||
data_format_);
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
|
||||
errors::InvalidArgument("Input to ", type_string(),
|
||||
" operator must have ", num_dims(),
|
||||
" dimensions"));
|
||||
|
||||
auto xla_data_format =
|
||||
XlaTensorFormat(data_format_, input_shape.dims() - 2);
|
||||
auto spatial_padding = MakeSpatialPadding(
|
||||
input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
|
||||
|
||||
// Convert the input to the reduction type.
|
||||
auto converted_input =
|
||||
ConvertElementType(ctx->Input(0), xla_reduction_type_);
|
||||
auto pooling =
|
||||
xla::AvgPool(converted_input, ksize, stride, spatial_padding,
|
||||
xla_data_format, padding_ == xla::Padding::kValid);
|
||||
// Convert the pooling result back to the input type before returning it.
|
||||
ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -104,7 +104,7 @@ class RetvalOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};

REGISTER_XLA_OP(Name("_Retval"), RetvalOp);
REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);

}  // anonymous namespace
}  // namespace tensorflow
@ -38,11 +38,15 @@ class SoftmaxOp : public XlaOpKernel {

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_shape.DebugString()));

    const int kBatchDim = 0;
    const int kClassDim = 1;
    // Major dimensions are batch dimensions, minor dimension is the class
    // dimension.
    std::vector<int64> batch_dims(logits_shape.dims() - 1);
    std::iota(batch_dims.begin(), batch_dims.end(), 0);
    const int kClassDim = logits_shape.dims() - 1;

    const DataType type = input_type(0);
    const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
@ -56,7 +60,7 @@ class SoftmaxOp : public XlaOpKernel {
        xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
    // Subtract the max in batch b from every element in batch b. Broadcasts
    // along the batch dimension.
    auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
    auto shifted_logits = xla::Sub(logits, logits_max, batch_dims);
    auto exp_shifted = xla::Exp(shifted_logits);
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    xla::PrimitiveType xla_accumulation_type;
@ -71,9 +75,9 @@ class SoftmaxOp : public XlaOpKernel {
    auto softmax =
        log_
            // softmax = shifted_logits - log(sum(exp(shifted_logits)))
            ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim})
            ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims)
            // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
            : xla::Div(exp_shifted, sum, {kBatchDim});
            : xla::Div(exp_shifted, sum, batch_dims);
    ctx->SetOutput(0, softmax);
  }
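The change above generalizes softmax from rank-2 logits to rank-N logits: every leading dimension becomes a batch dimension and the last dimension is the class dimension. A small standalone sketch of that dimension bookkeeping, not taken from this commit; the helper name BatchDims is hypothetical.

#include <cstdint>
#include <numeric>
#include <vector>

// For rank-N logits, dims 0..N-2 are batch dims; dim N-1 is the class dim.
std::vector<int64_t> BatchDims(int rank) {
  std::vector<int64_t> batch_dims(rank - 1);
  std::iota(batch_dims.begin(), batch_dims.end(), 0);  // {0, 1, ..., rank-2}
  return batch_dims;
}
// Sketch: for logits of shape [B, T, C], BatchDims(3) == {0, 1} and the class
// dim is 2, so the broadcasted Sub/Div above reduce over C only.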
@ -301,6 +301,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}

REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);

}  // namespace tensorflow
@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
  return Status::OK();
}

Status HostTensorToMutableBorrowingLiteral(
    Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(),
                                           host_tensor->shape(), &xla_shape));
  return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal);
}

Status HostTensorToMutableBorrowingLiteral(
    const xla::Shape& xla_shape, Tensor* host_tensor,
    xla::MutableBorrowingLiteral* literal) {
  *literal = xla::MutableBorrowingLiteral(
      static_cast<const char*>(DMAHelper::base(host_tensor)), xla_shape);

  return Status::OK();
}

Status HostTensorsToBorrowingLiteralTuple(
    tensorflow::gtl::ArraySlice<Tensor> host_tensors,
    xla::BorrowingLiteral* literal) {
@ -30,6 +30,16 @@ namespace tensorflow {
// 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
                                    xla::BorrowingLiteral* literal);
// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
// owned by 'host_tensor', but is mutable via the xla::Literal methods.
Status HostTensorToMutableBorrowingLiteral(
    Tensor* host_tensor, xla::MutableBorrowingLiteral* literal);
// Similar to the above, except the literal shape is explicitly provided and
// used instead of obtaining it from 'host_tensor'. The provided literal shape
// 'xla_shape' must be compatible with the shape of 'host_tensor'.
Status HostTensorToMutableBorrowingLiteral(
    const xla::Shape& xla_shape, Tensor* host_tensor,
    xla::MutableBorrowingLiteral* literal);

// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'.
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include <queue>
#include <random>
#include <set>
#include <unordered_map>

@ -297,4 +298,29 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
  }
}

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}
}  // namespace

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  return counter.fetch_add(2);
}

}  // namespace tensorflow
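The odd-start, increment-by-two scheme above guarantees the 32-bit seed can never become zero, even when the counter wraps around. A tiny standalone sketch of that invariant, not taken from this commit; the helper name StaysNonZero is hypothetical.

#include <cstdint>

// An odd 32-bit value incremented by 2 stays odd forever (parity is preserved
// under wraparound modulo 2^32), so it never reaches 0.
bool StaysNonZero(uint32_t start, int steps) {
  uint32_t seed = start | 1;  // Force the starting value to be odd.
  for (int i = 0; i < steps; ++i) {
    if (seed == 0) return false;  // Never triggers: odd values skip 0.
    seed += 2;                    // Wraps modulo 2^32, parity unchanged.
  }
  return true;
}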
@ -56,6 +56,9 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
                                   KernelDef* kdef);

// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
@ -14,7 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
|
||||
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@ -22,61 +21,42 @@ namespace tensorflow {
|
||||
|
||||
XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
|
||||
AllocMode alloc_mode)
|
||||
: raw_function_(static_data.raw_function),
|
||||
result_index_(static_data.result_index),
|
||||
args_(new void*[static_data.num_args]),
|
||||
temps_(new void*[static_data.num_temps]),
|
||||
arg_index_to_temp_index_(new int32[static_data.num_args]),
|
||||
num_args_(static_data.num_args),
|
||||
arg_names_(static_data.arg_names),
|
||||
result_names_(static_data.result_names),
|
||||
program_shape_(static_data.program_shape),
|
||||
hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
|
||||
: raw_function_(static_data.raw_function_),
|
||||
result_index_(static_data.result_index_),
|
||||
buffer_table_(new void*[static_data.num_buffers_]),
|
||||
buffer_infos_(static_data.buffer_infos_),
|
||||
arg_index_table_(static_data.arg_index_table_),
|
||||
num_args_(static_data.num_args_),
|
||||
arg_names_(static_data.arg_names_),
|
||||
result_names_(static_data.result_names_),
|
||||
program_shape_(static_data.program_shape_),
|
||||
hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
|
||||
bool allocate_entry_params =
|
||||
alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS;
|
||||
// Allocate arg and temp buffers.
|
||||
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
|
||||
alloc_args_ = cpu_function_runtime::MallocContiguousBuffers(
|
||||
static_data.arg_sizes, static_data.num_args, args_,
|
||||
/*annotate_initialized=*/false);
|
||||
}
|
||||
alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers(
|
||||
static_data.temp_sizes, static_data.num_temps, temps_,
|
||||
alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers(
|
||||
static_data.buffer_infos_, static_data.num_buffers_,
|
||||
/*allocate_entry_params=*/allocate_entry_params, buffer_table_,
|
||||
/*annotate_initialized=*/true);
|
||||
|
||||
for (int i = 0; i < static_data.num_temps; i++) {
|
||||
if (static_data.temp_sizes[i] < -1) {
|
||||
int32 param_number = -(static_data.temp_sizes[i] + 2);
|
||||
arg_index_to_temp_index_[param_number] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// If Hlo profiling is enabled the generated code expects an appropriately
|
||||
// sized buffer to be passed in as the last argument. If Hlo profiling is
|
||||
// disabled the last function argument is still present in the function
|
||||
// signature, but it is ignored by the generated code and we pass in null for
|
||||
// it.
|
||||
if (hlo_profiling_enabled()) {
|
||||
profile_counters_ = new int64[static_data.profile_counters_size]();
|
||||
profile_counters_ = new int64[static_data.profile_counters_size_]();
|
||||
}
|
||||
}
|
||||
|
||||
bool XlaCompiledCpuFunction::Run() {
|
||||
// Propagate pointers to the argument buffers into the temps array. Code
|
||||
// generated by XLA discovers the incoming argument pointers from the temps
|
||||
// array.
|
||||
for (int32 i = 0; i < num_args_; i++) {
|
||||
temps_[arg_index_to_temp_index_[i]] = args_[i];
|
||||
}
|
||||
raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
|
||||
profile_counters_);
|
||||
raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
|
||||
buffer_table_, profile_counters_);
|
||||
return true;
|
||||
}
|
||||
|
||||
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
|
||||
cpu_function_runtime::FreeContiguous(alloc_args_);
|
||||
cpu_function_runtime::FreeContiguous(alloc_temps_);
|
||||
delete[] args_;
|
||||
delete[] temps_;
|
||||
delete[] arg_index_to_temp_index_;
|
||||
cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
|
||||
delete[] buffer_table_;
|
||||
delete[] profile_counters_;
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -56,46 +57,85 @@ class XlaCompiledCpuFunction {
|
||||
// StaticData represents the state necessary to run an XLA-compiled
|
||||
// function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
|
||||
// AOT this is backed by data compiled into the object file.
|
||||
struct StaticData {
|
||||
//
|
||||
// The contents of StaticData are XLA-internal implementation details and
|
||||
// should not be relied on by clients.
|
||||
//
|
||||
// TODO(sanjoy): Come up with a cleaner way to express the constraint we want
|
||||
// here: generated XlaCompiledCpuFunction subclasses should be able to create
|
||||
// instances of StaticData but only XlaCompiledCpuFunction should be able to
|
||||
// read from StaticData instances.
|
||||
class StaticData {
|
||||
public:
|
||||
void set_raw_function(RawFunction raw_function) {
|
||||
raw_function_ = raw_function;
|
||||
}
|
||||
void set_buffer_infos(
|
||||
const cpu_function_runtime::BufferInfo* buffer_infos) {
|
||||
buffer_infos_ = buffer_infos;
|
||||
}
|
||||
void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; }
|
||||
void set_arg_index_table(const int32* arg_index_table) {
|
||||
arg_index_table_ = arg_index_table;
|
||||
}
|
||||
void set_num_args(int64 num_args) { num_args_ = num_args; }
|
||||
void set_result_index(size_t result_index) { result_index_ = result_index; }
|
||||
void set_arg_names(const char** arg_names) { arg_names_ = arg_names; }
|
||||
void set_result_names(const char** result_names) {
|
||||
result_names_ = result_names;
|
||||
}
|
||||
void set_program_shape(const xla::ProgramShape* program_shape) {
|
||||
program_shape_ = program_shape;
|
||||
}
|
||||
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
|
||||
return hlo_profile_printer_data_;
|
||||
}
|
||||
void set_hlo_profile_printer_data(
|
||||
const xla::HloProfilePrinterData* hlo_profile_printer_data) {
|
||||
hlo_profile_printer_data_ = hlo_profile_printer_data;
|
||||
}
|
||||
void set_profile_counters_size(int64 profile_counters_size) {
|
||||
profile_counters_size_ = profile_counters_size;
|
||||
}
|
||||
|
||||
private:
|
||||
// The raw function to call.
|
||||
RawFunction raw_function;
|
||||
RawFunction raw_function_;
|
||||
|
||||
// Cardinality and size of arg buffers.
|
||||
const intptr_t* arg_sizes = nullptr;
|
||||
size_t num_args = 0;
|
||||
// Contains information about the buffers used by the XLA computation.
|
||||
const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
|
||||
size_t num_buffers_ = 0;
|
||||
|
||||
// Cardinality and size of temp buffers.
|
||||
//
|
||||
// If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
|
||||
//
|
||||
// If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
|
||||
// corresponding entry in the temp buffer array needs to be set to null.
|
||||
//
|
||||
// If temp_sizes[i] < -1 then the i'th temp is the entry parameter
|
||||
// -(temp_sizes[i] + 2).
|
||||
const intptr_t* temp_sizes = nullptr;
|
||||
size_t num_temps = 0;
|
||||
// Entry parameter i is described by
|
||||
// buffer_infos[arg_index_table[i]].
|
||||
const int32* arg_index_table_ = nullptr;
|
||||
|
||||
// There are num_args entry parameters.
|
||||
int64 num_args_ = 0;
|
||||
|
||||
// The 0-based index of the result tuple, in the temp buffers.
|
||||
size_t result_index = 0;
|
||||
size_t result_index_ = 0;
|
||||
|
||||
// [Optional] Arrays of arg and result names. These are arrays of C-style
|
||||
// strings, where the array is terminated by nullptr.
|
||||
const char** arg_names = nullptr;
|
||||
const char** result_names = nullptr;
|
||||
const char** arg_names_ = nullptr;
|
||||
const char** result_names_ = nullptr;
|
||||
|
||||
// [Optional] Arg and result shapes.
|
||||
const xla::ProgramShape* program_shape = nullptr;
|
||||
const xla::ProgramShape* program_shape_ = nullptr;
|
||||
|
||||
// [Optional] Profile printer data. Null if profiling is disabled.
|
||||
const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
|
||||
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
||||
|
||||
// [Optional] The number of profile counters expected in the profile counter
|
||||
// buffer by the generated code and hlo_profile_printer. 0 if profiling is
|
||||
// disabled. This information is already present in
|
||||
// hlo_profile_printer_data but xla::HloProfilePrinterData is forward
|
||||
// declared so we don't have access to that information here.
|
||||
int64 profile_counters_size = 0;
|
||||
int64 profile_counters_size_ = 0;
|
||||
|
||||
// Only XlaCompiledCpuFunction is allowed to read the above fields.
|
||||
friend class XlaCompiledCpuFunction;
|
||||
};
|
||||
|
||||
// AllocMode controls the buffer allocation mode.
|
||||
@ -135,14 +175,25 @@ class XlaCompiledCpuFunction {
|
||||
// ------------------------------
|
||||
// Arg methods for managing input buffers. Buffers are in row-major order.
|
||||
|
||||
// Returns the underlying array of argument buffers, where args()[I] is the
|
||||
// buffer for the positional argument at index I.
|
||||
void** args() { return args_; }
|
||||
const void* const* args() const { return args_; }
|
||||
|
||||
// Returns the buffer for the positional argument at the given `index`.
|
||||
void* arg_data(size_t index) { return args_[index]; }
|
||||
const void* arg_data(size_t index) const { return args_[index]; }
|
||||
void* arg_data(size_t index) {
|
||||
return buffer_table_[arg_index_table_[index]];
|
||||
}
|
||||
const void* arg_data(size_t index) const {
|
||||
return buffer_table_[arg_index_table_[index]];
|
||||
}
|
||||
|
||||
int num_args() const { return num_args_; }
|
||||
|
||||
// Returns the size of entry parameter `idx`.
|
||||
//
|
||||
// There is a static version of this method on tfcompile generated subclasses
|
||||
// of XlaCompiledCpuFunction, but try to prefer this when possible since it
|
||||
// works both for XlaJitCompiledCpuFunction and AOT compiled subclasses.
|
||||
int arg_size(int idx) const {
|
||||
assert(idx < num_args());
|
||||
return buffer_infos_[arg_index_table_[idx]].size();
|
||||
}
|
||||
|
||||
// Sets the buffer for the positional argument at the given `index` to `data`.
|
||||
// Must be called before Run to have an effect. May be called under any
|
||||
@ -155,7 +206,9 @@ class XlaCompiledCpuFunction {
|
||||
//
|
||||
// Aliasing of argument and result buffers is not allowed, and results in
|
||||
// undefined behavior.
|
||||
void set_arg_data(size_t index, void* data) { args_[index] = data; }
|
||||
void set_arg_data(size_t index, void* data) {
|
||||
buffer_table_[arg_index_table_[index]] = data;
|
||||
}
|
||||
|
||||
// ------------------------------
|
||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||
@ -165,9 +218,9 @@ class XlaCompiledCpuFunction {
|
||||
|
||||
// Returns the underlying array of result buffers, where results()[I] is the
|
||||
// buffer for the positional result at index I.
|
||||
void** results() { return static_cast<void**>(temps_[result_index_]); }
|
||||
void** results() { return static_cast<void**>(buffer_table_[result_index_]); }
|
||||
const void* const* results() const {
|
||||
return static_cast<const void* const*>(temps_[result_index_]);
|
||||
return static_cast<const void* const*>(buffer_table_[result_index_]);
|
||||
}
|
||||
|
||||
// Profile counters for this XLA computation.
|
||||
@ -225,25 +278,28 @@ class XlaCompiledCpuFunction {
|
||||
const RawFunction raw_function_;
|
||||
const size_t result_index_;
|
||||
|
||||
// Arrays of argument and temp buffers; entries in args_ may be overwritten by
|
||||
// the user.
|
||||
void** args_ = nullptr;
|
||||
void** temps_ = nullptr;
|
||||
// Array containing pointers to argument and temp buffers (slots corresponding
|
||||
// to constant and on-stack buffers are null).
|
||||
void** const buffer_table_;
|
||||
|
||||
// Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
|
||||
// XLA generated code to be able to find it.
|
||||
// Describes the buffers used by the XLA computation.
|
||||
const cpu_function_runtime::BufferInfo* const buffer_infos_;
|
||||
|
||||
// Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]]
|
||||
// for XLA generated code to be able to find it.
|
||||
//
|
||||
// For now we need to keep around the args_ array because there is code that
|
||||
// depends on args() returning a void**. However, in the future we may remove
|
||||
// args_ in favor of using temps_ as the sole storage for the arguments.
|
||||
int32* arg_index_to_temp_index_;
|
||||
// args_ in favor of using buffer_table_ as the sole storage for the
|
||||
// arguments.
|
||||
const int32* const arg_index_table_;
|
||||
|
||||
// The number of incoming arguments.
|
||||
int32 num_args_;
|
||||
const int32 num_args_;
|
||||
|
||||
// Backing memory for individual arg and temp buffers.
|
||||
void* alloc_args_ = nullptr;
|
||||
void* alloc_temps_ = nullptr;
|
||||
// Backing memory for buffer_table_ and args_, the latter depending on
|
||||
// AllocMode.
|
||||
void* alloc_buffer_table_ = nullptr;
|
||||
|
||||
// Backing memory for profiling counters.
|
||||
int64* profile_counters_ = nullptr;
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -35,45 +36,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns a vector of positional argument buffer sizes.
|
||||
xla::StatusOr<std::vector<intptr_t>> ComputeArgSizes(
|
||||
const xla::ProgramShape& program_shape) {
|
||||
std::vector<intptr_t> arg_sizes;
|
||||
const size_t num_args = program_shape.parameters_size();
|
||||
arg_sizes.reserve(num_args);
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
const xla::Shape& arg_shape = program_shape.parameters(i);
|
||||
constexpr size_t kPointerSize = sizeof(void*);
|
||||
arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize));
|
||||
}
|
||||
return std::move(arg_sizes);
|
||||
}
|
||||
|
||||
// Returns a vector of positional temporary buffer sizes.
|
||||
xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
|
||||
const xla::BufferAssignment& buffer_assignment) {
|
||||
const std::vector<xla::BufferAllocation>& allocations =
|
||||
buffer_assignment.Allocations();
|
||||
std::vector<intptr_t> temp_sizes;
|
||||
temp_sizes.reserve(allocations.size());
|
||||
for (const xla::BufferAllocation& allocation : allocations) {
|
||||
if (allocation.is_constant() || allocation.is_thread_local()) {
|
||||
// Constants are lowered to globals. Thread locals are lowered to
|
||||
// allocas.
|
||||
temp_sizes.push_back(-1);
|
||||
} else if (allocation.is_entry_computation_parameter()) {
|
||||
// Entry computation parameters need some preprocessing in
|
||||
// XlaCompiledCpuFunction::Run. See the comment on
|
||||
// XlaCompiledCpuFunction::StaticData::temp_sizes.
|
||||
temp_sizes.push_back(-allocation.parameter_number() - 2);
|
||||
} else {
|
||||
temp_sizes.push_back(allocation.size());
|
||||
}
|
||||
}
|
||||
return std::move(temp_sizes);
|
||||
}
|
||||
|
||||
// Returns the index of the result in the temp buffers.
|
||||
xla::StatusOr<size_t> ComputeResultIndex(
|
||||
const xla::BufferAssignment& buffer_assignment) {
|
||||
@ -157,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile(
|
||||
const xla::BufferAssignment& buffer_assignment =
|
||||
cpu_executable->buffer_assignment();
|
||||
|
||||
// Compute buffer sizes and the result index, needed to run the raw function.
|
||||
TF_ASSIGN_OR_RETURN(std::vector<intptr_t> arg_sizes,
|
||||
ComputeArgSizes(*program_shape));
|
||||
TF_ASSIGN_OR_RETURN(std::vector<intptr_t> temp_sizes,
|
||||
ComputeTempSizes(buffer_assignment));
|
||||
// Compute buffer infos and the result index, needed to run the raw function.
|
||||
std::vector<cpu_function_runtime::BufferInfo> buffer_infos =
|
||||
xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment);
|
||||
std::vector<int32> arg_index_table =
|
||||
xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
|
||||
TF_ASSIGN_OR_RETURN(size_t result_index,
|
||||
ComputeResultIndex(buffer_assignment));
|
||||
|
||||
@ -169,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile(
|
||||
new XlaJitCompiledCpuFunction);
|
||||
XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get();
|
||||
jit->executable_ = std::move(executable);
|
||||
jit->arg_sizes_ = std::move(arg_sizes);
|
||||
jit->temp_sizes_ = std::move(temp_sizes);
|
||||
jit->buffer_infos_ = std::move(buffer_infos);
|
||||
jit->arg_index_table_ = std::move(arg_index_table);
|
||||
jit->program_shape_ = std::move(program_shape);
|
||||
jit->static_data_.raw_function = std::move(raw_function);
|
||||
jit->static_data_.arg_sizes = jit->arg_sizes_.data();
|
||||
jit->static_data_.num_args = jit->arg_sizes_.size();
|
||||
jit->static_data_.temp_sizes = jit->temp_sizes_.data();
|
||||
jit->static_data_.num_temps = jit->temp_sizes_.size();
|
||||
jit->static_data_.result_index = result_index;
|
||||
jit->static_data_.set_raw_function(raw_function);
|
||||
jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
|
||||
jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
|
||||
jit->static_data_.set_arg_index_table(jit->arg_index_table_.data());
|
||||
jit->static_data_.set_num_args(jit->arg_index_table_.size());
|
||||
jit->static_data_.set_result_index(result_index);
|
||||
// Optional metadata is collected and set below.
|
||||
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
|
||||
CollectNames(config.fetch(), &jit->nonempty_result_names_,
|
||||
&jit->result_names_);
|
||||
jit->static_data_.arg_names = jit->arg_names_.data();
|
||||
jit->static_data_.result_names = jit->result_names_.data();
|
||||
jit->static_data_.program_shape = jit->program_shape_.get();
|
||||
jit->static_data_.set_arg_names(jit->arg_names_.data());
|
||||
jit->static_data_.set_result_names(jit->result_names_.data());
|
||||
jit->static_data_.set_program_shape(jit->program_shape_.get());
|
||||
|
||||
if (cpu_executable->hlo_profiling_enabled()) {
|
||||
jit->static_data_.hlo_profile_printer_data =
|
||||
&cpu_executable->hlo_profile_printer_data();
|
||||
jit->static_data_.profile_counters_size =
|
||||
cpu_executable->hlo_profile_printer_data().profile_counters_size();
|
||||
jit->static_data_.set_hlo_profile_printer_data(
|
||||
&cpu_executable->hlo_profile_printer_data());
|
||||
jit->static_data_.set_profile_counters_size(
|
||||
cpu_executable->hlo_profile_printer_data().profile_counters_size());
|
||||
}
|
||||
|
||||
return std::move(jit_unique_ptr);
|
||||
|
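The JIT path above now fills in StaticData through setters instead of writing public fields. A hedged sketch of how an AOT-generated subclass might do the same, not taken from this commit; it uses only setters that appear in this diff, and kBufferInfos, kArgIndexTable, and the chosen sizes and indices are hypothetical placeholder values. The raw-function setter is omitted because the sketch has no compiled entry point to reference.

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"

namespace {
using tensorflow::cpu_function_runtime::BufferInfo;

// Hypothetical buffer layout: one 16-byte entry parameter (buffer 0), one
// temp buffer, and the result tuple living in buffer 2.
const BufferInfo kBufferInfos[] = {
    BufferInfo::MakeEntryParameter(/*size=*/16, /*param_number=*/0),
    BufferInfo::MakeTempBuffer(32),
    BufferInfo::MakeTempBuffer(64),
};
const tensorflow::int32 kArgIndexTable[] = {0};  // arg 0 -> buffer 0

tensorflow::XlaCompiledCpuFunction::StaticData MakeStaticData() {
  tensorflow::XlaCompiledCpuFunction::StaticData data;
  // data.set_raw_function(...) would point at the generated entry point.
  data.set_buffer_infos(kBufferInfos);
  data.set_num_buffers(3);
  data.set_arg_index_table(kArgIndexTable);
  data.set_num_args(1);
  data.set_result_index(2);
  return data;
}
}  // namespace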
@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction {
  // The static data is backed by the rest of the state in this class.
  XlaCompiledCpuFunction::StaticData static_data_;

  // The backing arrays of arg and temp buffer sizes.
  std::vector<intptr_t> arg_sizes_;
  std::vector<intptr_t> temp_sizes_;
  // The backing array for buffer infos.
  std::vector<cpu_function_runtime::BufferInfo> buffer_infos_;

  // The backing array for the arg index table.
  std::vector<int32> arg_index_table_;

  // The backing arrays of arg and result names. We hold the actual strings in
  // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
@ -409,7 +409,7 @@ class Array {

  // Returns the total number of elements in the array.
  int64 num_elements() const {
    return std::accumulate(sizes_.begin(), sizes_.end(), 1,
    return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
                           std::multiplies<int64>());
  }
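The one-character change above matters because std::accumulate deduces its accumulator type from the initial value: with plain 1 the running product is an int and can overflow for large shapes, while 1LL keeps the whole product in 64 bits. A small standalone illustration, not taken from this commit; NumElements is a hypothetical helper.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t NumElements(const std::vector<int64_t>& sizes) {
  // With an int initial value of 1 the accumulator would be int and a shape
  // like {65536, 65536} would overflow; 1LL keeps the accumulator 64-bit.
  return std::accumulate(sizes.begin(), sizes.end(), 1LL,
                         std::multiplies<int64_t>());
}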
@ -121,6 +121,30 @@ xla_test(
    ],
)

cc_library(
    name = "pooling",
    srcs = ["pooling.cc"],
    hdrs = ["pooling.h"],
    deps = [
        ":arithmetic",
        ":constants",
        "//tensorflow/compiler/tf2xla/lib:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "pooling_test",
    srcs = ["pooling_test.cc"],
    deps = [
        ":pooling",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "prng",
    srcs = ["prng.cc"],
@ -144,7 +168,7 @@ cc_library(
        ":numeric",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

@ -161,7 +185,7 @@ xla_test(
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
tensorflow/compiler/xla/client/lib/pooling.cc (new file, 183 lines)
@ -0,0 +1,183 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/pooling.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
|
||||
// element of an image by the count of elements that contributed to that
|
||||
// element during pooling.
|
||||
XlaOp AvgPoolDivideByCountWithGeneralPadding(
|
||||
XlaOp sums, PrimitiveType dtype,
|
||||
tensorflow::gtl::ArraySlice<int64> input_shape,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
|
||||
tensorflow::gtl::ArraySlice<int64> ksize,
|
||||
tensorflow::gtl::ArraySlice<int64> stride,
|
||||
const TensorFormat& data_format) {
|
||||
// The padding shouldn't be included in the counts. We use another
|
||||
// ReduceWindow to find the right counts.
|
||||
const int num_spatial_dims = spatial_padding.size();
|
||||
|
||||
std::vector<int64> input_dim_sizes(num_spatial_dims);
|
||||
std::vector<int64> window_dims(num_spatial_dims);
|
||||
std::vector<int64> window_ksize(num_spatial_dims);
|
||||
std::vector<int64> window_stride(num_spatial_dims);
|
||||
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
|
||||
<< "Invalid number of spatial dimensions in data format specification";
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
int dim = data_format.spatial_dimension(i);
|
||||
input_dim_sizes[i] = input_shape[dim];
|
||||
window_dims[i] = dim;
|
||||
window_ksize[i] = ksize[dim];
|
||||
window_stride[i] = stride[dim];
|
||||
}
|
||||
|
||||
XlaBuilder* b = sums.builder();
|
||||
// Build a matrix of all 1s, with the same width/height as the input.
|
||||
auto ones = Broadcast(One(b, dtype), input_dim_sizes);
|
||||
PaddingConfig padding_config;
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
auto dims = padding_config.add_dimensions();
|
||||
dims->set_edge_padding_low(spatial_padding[i].first);
|
||||
dims->set_edge_padding_high(spatial_padding[i].second);
|
||||
}
|
||||
auto zero = Zero(b, dtype);
|
||||
auto padded_ones = Pad(ones, zero, padding_config);
|
||||
|
||||
// Perform a ReduceWindow with the same window size, strides, and padding
|
||||
// to count the number of contributions to each result element.
|
||||
auto counts =
|
||||
ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
|
||||
window_ksize, window_stride, Padding::kValid);
|
||||
|
||||
return Div(sums, counts, window_dims);
|
||||
}
|
||||
|
||||
// Sums all elements in the window specified by 'kernel_size' and 'stride'.
|
||||
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
|
||||
tensorflow::gtl::ArraySlice<int64> kernel_size,
|
||||
tensorflow::gtl::ArraySlice<int64> stride,
|
||||
const TensorFormat& data_format) {
|
||||
XlaBuilder* b = operand.builder();
|
||||
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
|
||||
PrimitiveType accumulation_type = init_shape.element_type();
|
||||
auto add_computation = CreateScalarAddComputation(accumulation_type, b);
|
||||
return ReduceWindow(operand, init_value, add_computation, kernel_size,
|
||||
stride, Padding::kValid);
|
||||
});
|
||||
}
|
||||
|
||||
// Creates a padding configuration out of spatial padding values.
|
||||
PaddingConfig MakeSpatialPaddingConfig(
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
|
||||
tensorflow::gtl::ArraySlice<int64> kernel_size,
|
||||
tensorflow::gtl::ArraySlice<int64> stride,
|
||||
const TensorFormat& data_format) {
|
||||
const int num_spatial_dims = kernel_size.size() - 2;
|
||||
PaddingConfig padding_config;
|
||||
for (int i = 0; i < 2 + num_spatial_dims; ++i) {
|
||||
padding_config.add_dimensions();
|
||||
}
|
||||
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
|
||||
<< "Invalid number of spatial dimensions in data format specification";
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
int dim = data_format.spatial_dimension(i);
|
||||
auto padding_dimension = padding_config.mutable_dimensions(dim);
|
||||
padding_dimension->set_edge_padding_low(spatial_padding[i].first);
|
||||
padding_dimension->set_edge_padding_high(spatial_padding[i].second);
|
||||
}
|
||||
return padding_config;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
|
||||
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
|
||||
const TensorFormat& data_format) {
|
||||
XlaBuilder* b = operand.builder();
|
||||
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
|
||||
PrimitiveType dtype = operand_shape.element_type();
|
||||
auto max_computation = CreateScalarMaxComputation(dtype, b);
|
||||
auto init_value = MinValue(b, dtype);
|
||||
return ReduceWindow(operand, init_value, max_computation, kernel_size,
|
||||
stride, padding);
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
|
||||
tensorflow::gtl::ArraySlice<int64> stride,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
|
||||
const TensorFormat& data_format,
|
||||
const bool counts_include_padding) {
|
||||
XlaBuilder* b = operand.builder();
|
||||
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
|
||||
PrimitiveType dtype = operand_shape.element_type();
|
||||
auto init_value = Zero(b, dtype);
|
||||
std::vector<int64> input_size(operand_shape.dimensions().begin(),
|
||||
operand_shape.dimensions().end());
|
||||
auto padding_config =
|
||||
MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format);
|
||||
auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
|
||||
auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
|
||||
data_format);
|
||||
if (counts_include_padding) {
|
||||
// If counts include padding, all windows have the same number of elements
|
||||
// contributing to each average. Divide by the window size everywhere to
|
||||
// get the average.
|
||||
int64 window_size =
|
||||
std::accumulate(kernel_size.begin(), kernel_size.end(), 1,
|
||||
[](int64 x, int64 y) { return x * y; });
|
||||
|
||||
auto divisor = ConstantR0WithType(b, dtype, window_size);
|
||||
return pooled / divisor;
|
||||
} else {
|
||||
return AvgPoolDivideByCountWithGeneralPadding(
|
||||
pooled, dtype, input_size, padding, kernel_size, stride, data_format);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
|
||||
tensorflow::gtl::ArraySlice<int64> input_size,
|
||||
tensorflow::gtl::ArraySlice<int64> kernel_size,
|
||||
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
|
||||
const TensorFormat& data_format) {
|
||||
const int num_spatial_dims = kernel_size.size() - 2;
|
||||
std::vector<int64> input_spatial_dimensions;
|
||||
std::vector<int64> kernel_size_spatial_dimensions;
|
||||
std::vector<int64> stride_spatial_dimensions;
|
||||
CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
|
||||
<< "Invalid number of spatial dimensions in data format specification";
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
int dim = data_format.spatial_dimension(i);
|
||||
input_spatial_dimensions.push_back(input_size[dim]);
|
||||
kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
|
||||
stride_spatial_dimensions.push_back(stride[dim]);
|
||||
}
|
||||
return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
|
||||
stride_spatial_dimensions, padding);
|
||||
}
|
||||
|
||||
} // namespace xla
|
tensorflow/compiler/xla/client/lib/pooling.h (new file, 73 lines)
@ -0,0 +1,73 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace xla {

// Tensor format for reduce window operations.
class TensorFormat {
 public:
  TensorFormat(int batch_dimension, int feature_dimension,
               tensorflow::gtl::ArraySlice<int64> spatial_dimensions)
      : batch_dimension_(batch_dimension),
        feature_dimension_(feature_dimension),
        spatial_dimensions_(spatial_dimensions.begin(),
                            spatial_dimensions.end()) {}

  int batch_dimension() const { return batch_dimension_; }

  int feature_dimension() const { return feature_dimension_; }

  int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; }

  int num_spatial_dims() const { return spatial_dimensions_.size(); }

 private:
  // The number of the dimension that represents the batch.
  int batch_dimension_;
  // The number of the dimension that represents the features.
  int feature_dimension_;
  // The dimension numbers for the spatial dimensions.
  tensorflow::gtl::InlinedVector<int, 4> spatial_dimensions_;
};

// Computes the max pool of 'operand'.
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
              tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
              const TensorFormat& data_format);

// Computes the average pool of 'operand'.
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
              tensorflow::gtl::ArraySlice<int64> stride,
              tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
              const TensorFormat& data_format,
              const bool counts_include_padding);

// Returns the list of low and high padding elements in each spatial dimension
// for the given 'padding' specification.
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
    tensorflow::gtl::ArraySlice<int64> input_size,
    tensorflow::gtl::ArraySlice<int64> kernel_size,
    tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
    const TensorFormat& data_format);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
|
tensorflow/compiler/xla/client/lib/pooling_test.cc (new file, 185 lines)
@ -0,0 +1,185 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/pooling.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TensorFormat MakeNCHWFormat(int num_spatial_dims) {
|
||||
tensorflow::gtl::InlinedVector<int64, 4> spatial_dimensions;
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
spatial_dimensions.push_back(i + 2);
|
||||
}
|
||||
return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1,
|
||||
/*spatial_dimensions=*/spatial_dimensions);
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> MakeGeneralPadding(
|
||||
XlaOp input, tensorflow::gtl::ArraySlice<int64> kernel_size,
|
||||
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
|
||||
const xla::TensorFormat& data_format) {
|
||||
XlaBuilder* b = input.builder();
|
||||
Shape operand_shape = b->GetShape(input).ValueOrDie();
|
||||
std::vector<int64> input_size(operand_shape.dimensions().begin(),
|
||||
operand_shape.dimensions().end());
|
||||
return MakeSpatialPadding(input_size, kernel_size, stride, padding,
|
||||
data_format);
|
||||
}
|
||||
|
||||
// Add singleton batch and feature dimensions to spatial dimensions, according
|
||||
// to 'data_format' specification.
|
||||
std::vector<int64> ExpandWithBatchAndFeatureDimensions(
|
||||
tensorflow::gtl::ArraySlice<int64> spatial_dim_sizes,
|
||||
const xla::TensorFormat& data_format) {
|
||||
const int num_spatial_dims = spatial_dim_sizes.size();
|
||||
std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
int dim = data_format.spatial_dimension(i);
|
||||
tensor_sizes[dim] = spatial_dim_sizes[i];
|
||||
}
|
||||
return tensor_sizes;
|
||||
}
|
||||
|
||||
class PoolingTest : public ClientLibraryTestBase {
 public:
  ErrorSpec error_spec_{0.0001};
};

XLA_TEST_F(PoolingTest, MaxPool2D) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  auto stride = kernel_size;
  MaxPool(input, kernel_size, stride, Padding::kValid, data_format);

  ComputeAndCompareR4<float>(&builder, {{{{5, 4}}}}, {}, error_spec_);
}

XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  auto stride = kernel_size;
  MaxPool(input, kernel_size, stride, Padding::kSame, data_format);

  ComputeAndCompareR4<float>(&builder, {{{{5, 4, 5}}}}, {}, error_spec_);
}

XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
  MaxPool(input, kernel_size, stride, Padding::kSame, data_format);

  ComputeAndCompareR4<float>(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}},
                             {}, error_spec_);
}

XLA_TEST_F(PoolingTest, AvgPool2D) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  auto stride = kernel_size;
  auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid,
                                    data_format);
  AvgPool(input, kernel_size, stride, padding, data_format,
          /*counts_include_padding=*/true);

  ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
}

XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  auto stride = kernel_size;
  auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
                                    data_format);
  AvgPool(input, kernel_size, stride, padding, data_format,
          /*counts_include_padding=*/false);

  ComputeAndCompareR4<float>(&builder, {{{{3, 3, 3}}}}, {}, error_spec_);
}

XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
  auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame,
                                    data_format);
  AvgPool(input, kernel_size, stride, padding, data_format,
          /*counts_include_padding=*/false);

  ComputeAndCompareR4<float>(&builder,
                             {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {},
                             error_spec_);
}

XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
  auto stride = kernel_size;
  AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format,
          /*counts_include_padding=*/false);

  ComputeAndCompareR4<float>(&builder, {{{{3, 3}}}}, {}, error_spec_);
}

XLA_TEST_F(PoolingTest,
           AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) {
  XlaBuilder builder(TestName());

  XlaOp input = ConstantR4FromArray4D<float>(
      &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}});
  auto data_format = MakeNCHWFormat(2);
  auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
  auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
  AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format,
          /*counts_include_padding=*/false);

  ComputeAndCompareR4<float>(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {},
                             error_spec_);
}

}  // namespace
}  // namespace xla
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
|
@ -303,7 +303,7 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
|
||||
const Shape& shape, int device_ordinal) {
|
||||
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
|
||||
backend().stream_executor(device_ordinal));
|
||||
auto literal = MakeUnique<Literal>();
|
||||
auto literal = Literal::CreateFromShape(shape);
|
||||
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
|
||||
executor, shape, literal.get()));
|
||||
return std::move(literal);
|
||||
|
@ -45,21 +45,6 @@ int64 GetUniqueId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
// Returns true if an instruction with the given opcode can be the root of the
|
||||
// computation.
|
||||
bool CanBeRoot(HloOpcode opcode) {
|
||||
switch (opcode) {
|
||||
case HloOpcode::kAfterAll:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kTrace:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
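// Illustrative use of the predicate above (derived from its switch statement;
// not part of the original source):
//
//   CanBeRoot(HloOpcode::kAdd);      // true: ordinary ops may be roots.
//   CanBeRoot(HloOpcode::kOutfeed);  // false: side-effecting, cannot be root.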
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaOp operator-(const XlaOp& x) { return Neg(x); }
|
||||
@ -142,28 +127,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
|
||||
return ReportErrorOrReturn(op_creator());
|
||||
}
|
||||
|
||||
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
|
||||
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
|
||||
TF_RET_CHECK(root_id != nullptr);
|
||||
TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
|
||||
|
||||
ProgramShape program_shape;
|
||||
|
||||
// Not all instructions can be roots. Walk backwards from the last added
|
||||
// instruction until a valid root is found.
|
||||
int64 index = instructions_.size() - 1;
|
||||
for (; index >= 0; index--) {
|
||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
|
||||
StringToHloOpcode(instructions_[index].opcode()));
|
||||
if (CanBeRoot(opcode)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (index < 0) {
|
||||
return FailedPrecondition("no root instruction was found");
|
||||
}
|
||||
*root_id = instructions_[index].id();
|
||||
*program_shape.mutable_result() = instructions_[index].shape();
|
||||
*program_shape.mutable_result() = instructions_[root_id].shape();
|
||||
|
||||
// Check that the parameter numbers are continuous from 0, and add parameter
|
||||
// shapes and names to the program shape.
|
||||
@ -188,8 +158,15 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
|
||||
}
|
||||
|
||||
StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
|
||||
int64 root;
|
||||
return GetProgramShape(&root);
|
||||
TF_RET_CHECK(!instructions_.empty());
|
||||
return GetProgramShape(instructions_.back().id());
|
||||
}
|
||||
|
||||
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
|
||||
if (root.builder_ != this) {
|
||||
return InvalidArgument("Given root operation is not in this computation.");
|
||||
}
|
||||
return GetProgramShape(root.handle());
|
||||
}
|
||||
|
||||
void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
||||
@ -257,17 +234,29 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
|
||||
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
|
||||
return AppendStatus(first_error_, backtrace);
|
||||
}
|
||||
return Build(instructions_.back().id());
|
||||
}
|
||||
|
||||
StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root) {
|
||||
if (root.builder_ != this) {
|
||||
return InvalidArgument("Given root operation is not in this computation.");
|
||||
}
|
||||
return Build(root.handle());
|
||||
}
|
||||
|
||||
StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
|
||||
if (!first_error_.ok()) {
|
||||
string backtrace;
|
||||
first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
|
||||
return AppendStatus(first_error_, backtrace);
|
||||
}
|
||||
|
||||
HloComputationProto entry;
|
||||
entry.set_id(GetUniqueId()); // Give the computation a global unique id.
|
||||
entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
|
||||
|
||||
{
|
||||
int64 root_id;
|
||||
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
|
||||
GetProgramShape(&root_id));
|
||||
TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
|
||||
entry.set_root_id(root_id);
|
||||
}
|
||||
|
||||
for (auto& instruction : instructions_) {
|
||||
// Ensures that the instruction names are unique among the whole graph.
|
||||
@ -1099,11 +1088,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
|
||||
sharding_builder::AssignDevice(0);
|
||||
XlaScopedShardingAssignment scoped_sharding(this,
|
||||
infeed_instruction_sharding);
|
||||
TF_ASSIGN_OR_RETURN(infeed,
|
||||
AddInstruction(std::move(instr), HloOpcode::kInfeed));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(infeed,
|
||||
AddInstruction(std::move(instr), HloOpcode::kInfeed));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
|
||||
}
|
||||
|
||||
// The infeed instruction produces a tuple of the infeed data and a token
|
||||
@ -1892,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum(
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
                           int64 concat_dimension, int64 split_count,
                           const std::vector<ReplicaGroup>& replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));

    // The HloInstruction for Alltoall currently only handles the data
    // communication: it accepts N already split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the hlo alltoall,
    // and concat the tuple elements.
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64 block_size =
        operand_shape.dimensions(split_dimension) / split_count;
    for (int i = 0; i < split_count; i++) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
    std::vector<const Shape*> slice_shape_ptrs;
    c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
                [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(
        *instr.mutable_shape(),
        ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    TF_ASSIGN_OR_RETURN(
        XlaOp alltoall,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));

    // Concat the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int i = 0; i < split_count; i++) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}
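// Illustrative walk-through of the decomposition above (hypothetical values,
// not part of the change): with an f32[4,16] operand, split_dimension=1,
// concat_dimension=0 and split_count=2, block_size is 16 / 2 = 8, so the
// operand is sliced into two f32[4,8] parts. The kAllToAll instruction
// exchanges the parts across replicas, and concatenating the two received
// f32[4,8] tuple elements along dimension 0 yields an f32[8,8] result, which
// is the shape the XlaBuilderTest.AllToAll case later in this change expects.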
|
||||
|
||||
XlaOp XlaBuilder::SelectAndScatter(
|
||||
const XlaOp& operand, const XlaComputation& select,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
@ -2163,11 +2207,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
|
||||
LookUpInstruction(root_op));
|
||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
|
||||
if (!CanBeRoot(opcode)) {
|
||||
return InvalidArgument("the operand with opcode %s cannot be root",
|
||||
root->opcode().c_str());
|
||||
}
|
||||
|
||||
HloComputationProto entry;
|
||||
entry.set_id(GetUniqueId()); // Give the computation a global unique id.
|
||||
@ -2693,6 +2732,13 @@ XlaOp CrossReplicaSum(
|
||||
replica_group_ids, channel_id);
|
||||
}
|
||||
|
||||
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
|
||||
int64 concat_dimension, int64 split_count,
|
||||
const std::vector<ReplicaGroup>& replica_groups) {
|
||||
return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
|
||||
split_count, replica_groups);
|
||||
}
|
||||
|
||||
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
|
@ -195,9 +195,14 @@ class XlaBuilder {
|
||||
|
||||
// Builds the computation with the requested operations, or returns a non-ok
|
||||
// status. Note that all ops that have been enqueued will be moved to the
|
||||
// computation being returned.
|
||||
// computation being returned. The root of the computation will be the last
|
||||
// added operation.
|
||||
StatusOr<XlaComputation> Build();
|
||||
|
||||
// Overload of Build which specifies a particular root instruction for the
|
||||
// computation.
|
||||
StatusOr<XlaComputation> Build(XlaOp root);
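// Hedged usage sketch (names are illustrative, not part of this header): any
// op created on this builder may be chosen as the root; per the
// BuildWithSpecificRootAndMultipleParameters test added in this change, the
// other enqueued ops are still moved into the computation.
//
//   XlaBuilder b("build_with_root");
//   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "p");
//   XlaOp sum = Add(p, p);
//   StatusOr<XlaComputation> computation = b.Build(/*root=*/p);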
|
||||
|
||||
// Builds the computation with the requested operations, or notes an error in
|
||||
// the parent XlaBuilder and returns an empty computation if building failed.
|
||||
// This function is intended to be used where the returned XlaComputation is
|
||||
@ -225,9 +230,14 @@ class XlaBuilder {
|
||||
// Returns the shape of the given op.
|
||||
StatusOr<Shape> GetShape(const XlaOp& op) const;
|
||||
|
||||
// Returns the (inferred) result for the current computation's shape.
|
||||
// Returns the (inferred) result for the current computation's shape. This
|
||||
// assumes the root instruction is the last added instruction.
|
||||
StatusOr<ProgramShape> GetProgramShape() const;
|
||||
|
||||
// Returns the (inferred) result for the current computation's shape using the
|
||||
// given operation as the root.
|
||||
StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
|
||||
|
||||
// Reports an error to the builder, by
|
||||
// * storing it internally and capturing a backtrace if it's the first error
|
||||
// (this deferred value will be produced on the call to
|
||||
@ -255,6 +265,9 @@ class XlaBuilder {
|
||||
StatusOr<bool> IsConstant(const XlaOp& operand) const;
|
||||
|
||||
private:
|
||||
// Build helper which takes the id of the root operation.
|
||||
StatusOr<XlaComputation> Build(int64 root_id);
|
||||
|
||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||
// passed to the computation.
|
||||
XlaOp Parameter(int64 parameter_number, const Shape& shape,
|
||||
@ -686,9 +699,9 @@ class XlaBuilder {
|
||||
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
|
||||
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
|
||||
//
|
||||
// - `channel_id`: for Allreduce nodes from different models, if they have the
|
||||
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
|
||||
// applied cross models.
|
||||
// - `channel_id`: for Allreduce nodes from different modules, if they have
|
||||
// the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
|
||||
// not be applied cross modules.
|
||||
//
|
||||
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
|
||||
XlaOp CrossReplicaSum(
|
||||
@ -697,6 +710,13 @@ class XlaBuilder {
|
||||
const tensorflow::gtl::optional<ChannelHandle>& channel_id =
|
||||
tensorflow::gtl::nullopt);
|
||||
|
||||
// Enqueues an operation that does an AllToAll of the operand across cores.
|
||||
//
|
||||
// TODO(b/110096724): This is NOT YET ready to use.
|
||||
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
|
||||
int64 concat_dimension, int64 split_count,
|
||||
const std::vector<ReplicaGroup>& replica_groups);
|
||||
|
||||
// Enqueues an operation that scatters the `source` array to the selected
|
||||
// indices of each window.
|
||||
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
||||
@ -969,9 +989,8 @@ class XlaBuilder {
|
||||
// shape.
|
||||
StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
|
||||
|
||||
// Returns the (inferred) result for the program shape for the current
|
||||
// computation and fills the root_id in the pointer.
|
||||
StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
|
||||
// Returns the (inferred) result for the program shape using the given root.
|
||||
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
|
||||
|
||||
// Returns shapes for the operands.
|
||||
StatusOr<std::vector<Shape>> GetOperandShapes(
|
||||
@ -1234,6 +1253,9 @@ class XlaBuilder {
|
||||
const XlaOp& operand, const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
|
||||
const tensorflow::gtl::optional<ChannelHandle>& channel_id);
|
||||
friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
|
||||
int64 concat_dimension, int64 split_count,
|
||||
const std::vector<ReplicaGroup>& replica_groups);
|
||||
friend XlaOp SelectAndScatter(
|
||||
const XlaOp& operand, const XlaComputation& select,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
@ -1820,9 +1842,9 @@ XlaOp CrossReplicaSum(
|
||||
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
|
||||
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
|
||||
//
|
||||
// - `channel_id`: for Allreduce nodes from different models, if they have the
|
||||
// - `channel_id`: for Allreduce nodes from different modules, if they have the
|
||||
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
|
||||
// applied cross models.
|
||||
// applied cross modules.
|
||||
//
|
||||
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
|
||||
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
|
||||
@ -1830,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
|
||||
const tensorflow::gtl::optional<ChannelHandle>&
|
||||
channel_id = tensorflow::gtl::nullopt);
|
||||
|
||||
// Enqueues an operation that does an AllToAll of the operand across cores.
|
||||
//
|
||||
// TODO(b/110096724): This is NOT YET ready to use.
|
||||
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
|
||||
int64 concat_dimension, int64 split_count,
|
||||
const std::vector<ReplicaGroup>& replica_groups = {});
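// Hedged usage sketch (illustrative only; it mirrors the XlaBuilderTest.AllToAll
// case added in this change):
//
//   XlaBuilder b("alltoall_example");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
//   AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
//            /*split_count=*/2);  // inferred result shape: f32[8,8]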
|
||||
|
||||
// Enqueues an operation that scatters the `source` array to the selected
|
||||
// indices of each window.
|
||||
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
@ -46,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test {
|
||||
return HloModule::CreateFromProto(proto, config);
|
||||
}
|
||||
|
||||
// Overload which explicitly specifies the root instruction.
|
||||
StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
|
||||
XlaOp root) {
|
||||
TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root));
|
||||
const HloModuleProto& proto = computation.proto();
|
||||
TF_ASSIGN_OR_RETURN(const auto& config,
|
||||
HloModule::CreateModuleConfigFromProto(
|
||||
proto, legacy_flags::GetDebugOptionsFromFlags()));
|
||||
return HloModule::CreateFromProto(proto, config);
|
||||
}
|
||||
|
||||
// Returns the name of the test currently being run.
|
||||
string TestName() const {
|
||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||
@ -293,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) {
|
||||
EXPECT_THAT(root, op::Transpose(op::Parameter()));
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, AllToAll) {
|
||||
XlaBuilder b(TestName());
|
||||
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
|
||||
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
|
||||
/*split_count=*/2);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
|
||||
// AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
|
||||
EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, ReportError) {
|
||||
XlaBuilder b(TestName());
|
||||
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
|
||||
@ -320,5 +347,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
|
||||
EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
|
||||
XlaBuilder b(TestName());
|
||||
XlaOp constant = ConstantR0<float>(&b, 1.0);
|
||||
Add(constant, ConstantR0<float>(&b, 2.0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Constant());
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
|
||||
// Specifying a particular root in Build should still include all entry
|
||||
// parameters.
|
||||
XlaBuilder b(TestName());
|
||||
const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
|
||||
XlaOp x = Parameter(&b, 0, shape, "x");
|
||||
XlaOp y = Parameter(&b, 1, shape, "y");
|
||||
XlaOp z = Parameter(&b, 2, shape, "z");
|
||||
Add(x, Sub(y, z));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Parameter());
|
||||
EXPECT_EQ(module->entry_computation()->num_parameters(), 3);
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(), 5);
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
|
||||
XlaBuilder b(TestName());
|
||||
XlaBuilder other_b(TestName());
|
||||
const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
|
||||
|
||||
Parameter(&b, 0, shape, "param");
|
||||
XlaOp other_param = Parameter(&other_b, 0, shape, "other_param");
|
||||
|
||||
Status status = b.Build(other_param).status();
|
||||
ASSERT_IS_NOT_OK(status);
|
||||
EXPECT_THAT(
|
||||
status.error_message(),
|
||||
::testing::HasSubstr("root operation is not in this computation"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -1,33 +0,0 @@
|
||||
# Description:
|
||||
# The new XLA client libraries.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = [":friends"])
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = [
|
||||
"//tensorflow/compiler/xla:friends",
|
||||
],
|
||||
)
|
||||
|
||||
# Filegroup used to collect source files for dependency checking.
|
||||
filegroup(
|
||||
name = "c_srcs",
|
||||
data = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "xla_builder",
|
||||
hdrs = ["xla_builder.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
],
|
||||
)
|
@ -71,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {
|
||||
return out;
|
||||
}
|
||||
|
||||
Literal::StrideConfig::StrideConfig(
|
||||
MutableLiteralBase::StrideConfig::StrideConfig(
|
||||
const Shape& source_shape, const Shape& dest_shape,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions)
|
||||
: dimensions(dimensions),
|
||||
@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
|
||||
}
|
||||
|
||||
Literal::Literal(const Shape& shape, bool allocate_arrays)
|
||||
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
|
||||
: MutableLiteralBase() {
|
||||
shape_ = MakeUnique<Shape>(shape);
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() {
|
||||
});
|
||||
}
|
||||
|
||||
Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
|
||||
Literal::Literal(Literal&& other) : MutableLiteralBase() {
|
||||
*this = std::move(other);
|
||||
}
|
||||
|
||||
Literal& Literal::operator=(Literal&& other) {
|
||||
DCHECK(&other.root_piece_->subshape() == other.shape_.get());
|
||||
@ -187,12 +190,13 @@ const SparseIndexArray* LiteralBase::sparse_indices(
|
||||
return piece(shape_index).sparse_indices();
|
||||
}
|
||||
|
||||
SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
|
||||
SparseIndexArray* MutableLiteralBase::sparse_indices(
|
||||
const ShapeIndex& shape_index) {
|
||||
return piece(shape_index).sparse_indices();
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
Status Literal::CopySliceFromInternal(
|
||||
Status MutableLiteralBase::CopySliceFromInternal(
|
||||
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
|
||||
tensorflow::gtl::ArraySlice<int64> dest_base,
|
||||
tensorflow::gtl::ArraySlice<int64> copy_size) {
|
||||
@ -225,7 +229,7 @@ Status Literal::CopySliceFromInternal(
|
||||
// proper stride size at the matching dimension.
|
||||
DimensionVector src_indexes(src_base.size(), 0);
|
||||
DimensionVector dest_indexes(dest_base.size(), 0);
|
||||
Literal::StrideConfig stride_config(src_literal.shape(), shape(),
|
||||
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
|
||||
copy_size);
|
||||
|
||||
auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
|
||||
@ -253,7 +257,8 @@ Status Literal::CopySliceFromInternal(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
|
||||
Status MutableLiteralBase::CopyElementFrom(
|
||||
const LiteralSlice& src_literal,
|
||||
tensorflow::gtl::ArraySlice<int64> src_index,
|
||||
tensorflow::gtl::ArraySlice<int64> dest_index) {
|
||||
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
|
||||
@ -275,8 +280,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
|
||||
const LiteralProto& proto) {
|
||||
/* static */ StatusOr<std::unique_ptr<Literal>>
|
||||
MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
|
||||
if (!proto.has_shape()) {
|
||||
return InvalidArgument("LiteralProto has no shape");
|
||||
}
|
||||
@ -405,7 +410,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Literal::CopyFrom(const LiteralSlice& src_literal,
|
||||
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
|
||||
const ShapeIndex& dest_shape_index,
|
||||
const ShapeIndex& src_shape_index) {
|
||||
const Shape& dest_subshape =
|
||||
@ -482,7 +487,8 @@ Status Literal::MoveFrom(Literal&& src_literal,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
|
||||
Status MutableLiteralBase::CopySliceFrom(
|
||||
const LiteralSlice& src_literal,
|
||||
tensorflow::gtl::ArraySlice<int64> src_base,
|
||||
tensorflow::gtl::ArraySlice<int64> dest_base,
|
||||
tensorflow::gtl::ArraySlice<int64> copy_size) {
|
||||
@ -543,7 +549,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
|
||||
shape().element_type());
|
||||
}
|
||||
|
||||
void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
|
||||
void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
|
||||
CHECK(ShapeUtil::IsArray(shape()));
|
||||
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
|
||||
CHECK_EQ(element_count(), values.bits());
|
||||
@ -895,8 +901,8 @@ size_t LiteralBase::Hash() const {
|
||||
return hash_value;
|
||||
}
|
||||
|
||||
Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
|
||||
int64 value) {
|
||||
Status MutableLiteralBase::SetIntegralAsS64(
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
|
||||
CHECK(LayoutUtil::IsDenseArray(shape()));
|
||||
switch (shape().element_type()) {
|
||||
case PRED:
|
||||
@ -933,7 +939,7 @@ tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
|
||||
return p.sparse_indices()->At(sparse_element_number);
|
||||
}
|
||||
|
||||
void Literal::SortSparseElements(const ShapeIndex& shape_index) {
|
||||
void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
|
||||
piece(shape_index).SortSparseElements();
|
||||
}
|
||||
|
||||
@ -1391,11 +1397,11 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
|
||||
elements.push_back(std::move(*new_element));
|
||||
}
|
||||
auto converted = MakeUnique<Literal>();
|
||||
*converted = Literal::MoveIntoTuple(&elements);
|
||||
*converted = MutableLiteralBase::MoveIntoTuple(&elements);
|
||||
return std::move(converted);
|
||||
}
|
||||
|
||||
/* static */ Literal Literal::MoveIntoTuple(
|
||||
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
|
||||
tensorflow::gtl::MutableArraySlice<Literal> elements) {
|
||||
std::vector<Shape> element_shapes;
|
||||
for (const Literal& element : elements) {
|
||||
@ -1808,7 +1814,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
|
||||
} // namespace
|
||||
|
||||
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
|
||||
// These conditions should have been checked in Literal::CreateFromProto.
|
||||
// These conditions should have been checked in
|
||||
// MutableLiteralBase::CreateFromProto.
|
||||
TF_RET_CHECK(proto.has_shape());
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
|
||||
@ -1900,7 +1907,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
|
||||
return piece(shape_index).untyped_data();
|
||||
}
|
||||
|
||||
void* Literal::untyped_data(const ShapeIndex& shape_index) {
|
||||
void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
|
||||
return piece(shape_index).untyped_data();
|
||||
}
|
||||
|
||||
@ -1916,6 +1923,127 @@ string LiteralBase::GetR1U8AsString() const {
|
||||
ShapeUtil::ElementsIn(shape()));
|
||||
}
|
||||
|
||||
void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
|
||||
Piece* src_piece,
|
||||
Piece* dest_piece) {
|
||||
DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
|
||||
<< "src_piece has shape: "
|
||||
<< ShapeUtil::HumanString(src_piece->subshape())
|
||||
<< "dest_piece has shape: "
|
||||
<< ShapeUtil::HumanString(dest_piece->subshape());
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
|
||||
const Shape& subshape = shape.tuple_shapes(i);
|
||||
|
||||
auto child_piece = Piece();
|
||||
child_piece.set_subshape(&subshape);
|
||||
|
||||
CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
|
||||
|
||||
dest_piece->emplace_back(std::move(child_piece));
|
||||
}
|
||||
} else if (ShapeUtil::IsArray(shape)) {
|
||||
dest_piece->set_buffer(src_piece->buffer());
|
||||
} else {
|
||||
// If the shape is neither an array nor tuple, then it must be
|
||||
// zero-sized. Otherwise, some memory needs to be allocated for it.
|
||||
CHECK_EQ(dest_piece->size_bytes(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
MutableLiteralBase::~MutableLiteralBase() {}
|
||||
|
||||
MutableBorrowingLiteral::MutableBorrowingLiteral(
|
||||
const MutableBorrowingLiteral& literal)
|
||||
: MutableLiteralBase() {
|
||||
shape_ = MakeUnique<Shape>(literal.shape());
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
|
||||
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
|
||||
}
|
||||
|
||||
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
|
||||
const MutableBorrowingLiteral& literal) {
|
||||
shape_ = MakeUnique<Shape>(literal.shape());
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
|
||||
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
MutableBorrowingLiteral::MutableBorrowingLiteral(
|
||||
const MutableLiteralBase& literal)
|
||||
: MutableLiteralBase() {
|
||||
shape_ = MakeUnique<Shape>(literal.shape());
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
|
||||
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
|
||||
}
|
||||
|
||||
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
|
||||
: MutableLiteralBase() {
|
||||
shape_ = MakeUnique<Shape>(literal->shape());
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
|
||||
CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
|
||||
}
|
||||
|
||||
MutableBorrowingLiteral::MutableBorrowingLiteral(
|
||||
MutableBorrowingLiteral literal, const ShapeIndex& view_root)
|
||||
: MutableLiteralBase() {
|
||||
shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape());
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
|
||||
CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
|
||||
}
|
||||
|
||||
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
|
||||
const Shape& shape)
|
||||
: MutableLiteralBase() {
|
||||
shape_ = MakeUnique<Shape>(shape);
|
||||
CHECK(LayoutUtil::HasLayout(*shape_));
|
||||
CHECK(!ShapeUtil::IsTuple(*shape_));
|
||||
|
||||
root_piece_ = new Piece();
|
||||
root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
|
||||
root_piece_->set_subshape(shape_.get());
|
||||
}
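// Hedged usage sketch of the constructor above (illustrative, not part of this
// change): wrapping a caller-owned array so XLA code can write into it in
// place. The buffer must outlive the borrowing literal.
//
//   float backing[4] = {0, 0, 0, 0};
//   MutableBorrowingLiteral borrowed(
//       reinterpret_cast<const char*>(backing),
//       ShapeUtil::MakeShape(F32, {4}));
//   borrowed.Set<float>({0}, 42.0f);  // writes directly into `backing`.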
|
||||
|
||||
MutableBorrowingLiteral::~MutableBorrowingLiteral() {
|
||||
if (root_piece_ != nullptr) {
|
||||
root_piece_->ForEachMutableSubpiece(
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
if (piece->buffer() != nullptr) {
|
||||
delete piece->sparse_indices();
|
||||
}
|
||||
});
|
||||
delete root_piece_;
|
||||
}
|
||||
}
|
||||
|
||||
LiteralSlice::LiteralSlice(const LiteralBase& literal)
|
||||
: LiteralBase(), root_piece_(&literal.root_piece()) {}
|
||||
|
||||
LiteralSlice::LiteralSlice(const LiteralBase& literal,
|
||||
const ShapeIndex& view_root)
|
||||
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
|
||||
|
||||
void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
|
||||
CHECK(ShapeUtil::IsTuple(shape));
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
|
||||
@ -1932,13 +2060,6 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
|
||||
}
|
||||
}
|
||||
|
||||
LiteralSlice::LiteralSlice(const LiteralBase& literal)
|
||||
: LiteralBase(), root_piece_(&literal.root_piece()) {}
|
||||
|
||||
LiteralSlice::LiteralSlice(const LiteralBase& literal,
|
||||
const ShapeIndex& view_root)
|
||||
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
|
||||
|
||||
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
|
||||
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
|
||||
CHECK(ShapeUtil::IsArray(*shape_));
|
||||
|
@ -310,9 +310,10 @@ class LiteralBase {
|
||||
// type of literal itself (0 for numeric types, and false for predicates).
|
||||
//
|
||||
// Note: It's an antipattern to use this method then immediately call
|
||||
// Literal::Populate on the result (since that results in zero initialization,
|
||||
// then reinitialization). Consider if a call to MakeUnique<Literal>(shape),
|
||||
// followed by the call to Literal::Populate can be used instead.
|
||||
// MutableLiteralBase::Populate on the result (since that results in zero
|
||||
// initialization, then reinitialization). Consider if a call to
|
||||
// MakeUnique<Literal>(shape), followed by the call to
|
||||
// MutableLiteralBase::Populate can be used instead.
|
||||
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
|
||||
|
||||
protected:
|
||||
@ -534,7 +535,7 @@ class LiteralBase {
|
||||
virtual const Piece& root_piece() const = 0;
|
||||
|
||||
// LiteralSlice and Literal must access Pieces of other Literals.
|
||||
friend class Literal;
|
||||
friend class MutableLiteralBase;
|
||||
friend class LiteralSlice;
|
||||
friend class BorrowingLiteral;
|
||||
|
||||
@ -545,33 +546,10 @@ class LiteralBase {
|
||||
tensorflow::gtl::ArraySlice<int64> start_indices) const;
|
||||
};
|
||||
|
||||
// Class representing literal values in XLA.
|
||||
//
|
||||
// The underlying buffer and shape is always owned by this class.
|
||||
class Literal : public LiteralBase {
|
||||
// Abstract base class representing a mutable literal in XLA.
|
||||
class MutableLiteralBase : public LiteralBase {
|
||||
public:
|
||||
Literal() : Literal(ShapeUtil::MakeNil()) {}
|
||||
|
||||
// Create a literal of the given shape. The literal is allocated sufficient
|
||||
// memory to hold the shape. Memory is uninitialized.
|
||||
explicit Literal(const Shape& shape);
|
||||
virtual ~Literal();
|
||||
|
||||
// Literals are moveable, but not copyable. To copy a literal use
|
||||
// Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
|
||||
// of literals which can be expensive.
|
||||
Literal(const Literal& other) = delete;
|
||||
Literal& operator=(const Literal& other) = delete;
|
||||
Literal(Literal&& other);
|
||||
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
|
||||
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
|
||||
// to nullptr.
|
||||
Literal(const Shape& shape, bool allocate_arrays);
|
||||
Literal& operator=(Literal&& other);
|
||||
|
||||
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
|
||||
// mutate the shape as this can produce malformed Literals.
|
||||
Shape* mutable_shape_do_not_use() { return shape_.get(); }
|
||||
virtual ~MutableLiteralBase() = 0;
|
||||
|
||||
// Returns a MutableArraySlice view of the array for this literal for the
|
||||
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
|
||||
@ -587,6 +565,10 @@ class Literal : public LiteralBase {
|
||||
// is not a sparse array.
|
||||
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
|
||||
|
||||
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
|
||||
// mutate the shape as this can produce malformed Literals.
|
||||
Shape* mutable_shape_do_not_use() { return shape_.get(); }
|
||||
|
||||
// Returns a pointer to the underlying buffer holding the array at the given
|
||||
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
|
||||
// is not array.
|
||||
@ -613,21 +595,6 @@ class Literal : public LiteralBase {
|
||||
const ShapeIndex& dest_shape_index = {},
|
||||
const ShapeIndex& src_shape_index = {});
|
||||
|
||||
// Returns a vector containing the tuple elements of this Literal as separate
|
||||
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
|
||||
// elements are moved into the new Literals; no data is copied. Upon return
|
||||
// this Literal is set to a nil shape (empty tuple)
|
||||
std::vector<Literal> DecomposeTuple();
|
||||
|
||||
// Similar to CopyFrom, but with move semantics. The subshape of this literal
|
||||
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
|
||||
// (layouts and shapes must match), but need not be arrays. The memory
|
||||
// allocated in this literal for the subshape at dest_shape_index is
|
||||
// deallocated, and the respective buffers are replaced with those in
|
||||
// src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
|
||||
Status MoveFrom(Literal&& src_literal,
|
||||
const ShapeIndex& dest_shape_index = {});
|
||||
|
||||
// Copies the values from src_literal, starting at src_base shape indexes,
|
||||
// to this literal, starting at dest_base, where the copy size in each
|
||||
// dimension is specified by copy_size.
|
||||
@ -730,12 +697,7 @@ class Literal : public LiteralBase {
|
||||
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
|
||||
const LiteralProto& proto);
|
||||
|
||||
private:
|
||||
// Recursively sets the subshapes and buffers of all subpieces rooted at
|
||||
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
|
||||
// the shape.
|
||||
void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
|
||||
|
||||
protected:
|
||||
// Returns the piece at the given ShapeIndex.
|
||||
Piece& piece(const ShapeIndex& shape_index) {
|
||||
return const_cast<Piece&>(LiteralBase::piece(shape_index));
|
||||
@ -783,12 +745,83 @@ class Literal : public LiteralBase {
|
||||
template <typename NativeT, typename FnType>
|
||||
Status PopulateInternal(const FnType& generator, bool parallel);
|
||||
|
||||
friend class LiteralBase;
|
||||
friend class MutableBorrowingLiteral;
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& out, const Literal& literal);
|
||||
|
||||
// The underlying buffer and shape is always owned by this class.
|
||||
class Literal : public MutableLiteralBase {
|
||||
public:
|
||||
Literal() : Literal(ShapeUtil::MakeNil()) {}
|
||||
|
||||
// Create a literal of the given shape. The literal is allocated sufficient
|
||||
// memory to hold the shape. Memory is uninitialized.
|
||||
explicit Literal(const Shape& shape);
|
||||
virtual ~Literal();
|
||||
|
||||
// Literals are moveable, but not copyable. To copy a literal use
|
||||
// Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
|
||||
// of literals which can be expensive.
|
||||
Literal(const Literal& other) = delete;
|
||||
Literal& operator=(const Literal& other) = delete;
|
||||
Literal(Literal&& other);
|
||||
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
|
||||
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
|
||||
// to nullptr.
|
||||
Literal(const Shape& shape, bool allocate_arrays);
|
||||
Literal& operator=(Literal&& other);
|
||||
|
||||
// Similar to CopyFrom, but with move semantics. The subshape of this literal
|
||||
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
|
||||
// (layouts and shapes must match), but need not be arrays. The memory
|
||||
// allocated in this literal for the subshape at dest_shape_index is
|
||||
// deallocated, and the respective buffers are replaced with those in
|
||||
// src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
|
||||
virtual Status MoveFrom(Literal&& src_literal,
|
||||
const ShapeIndex& dest_shape_index = {});
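// Hedged usage sketch (illustrative only): moving the buffers of one literal
// into another of equal shape without copying element data.
//
//   Literal src = ...;               // some literal of shape f32[2,2]
//   Literal dst(src.shape());        // destination with an equal shape
//   TF_RETURN_IF_ERROR(dst.MoveFrom(std::move(src)));
//   // `src` is now a nil (empty tuple) literal; `dst` owns the buffers.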
|
||||
|
||||
// Returns a vector containing the tuple elements of this Literal as separate
|
||||
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
|
||||
// elements are moved into the new Literals; no data is copied. Upon return
|
||||
// this Literal is set to a nil shape (empty tuple)
|
||||
std::vector<Literal> DecomposeTuple();
|
||||
|
||||
private:
|
||||
// Deallocate the buffers held by this literal.
|
||||
void DeallocateBuffers();
|
||||
|
||||
friend class LiteralBase;
|
||||
// Recursively sets the subshapes and buffers of all subpieces rooted at
|
||||
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
|
||||
// the shape.
|
||||
void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
|
||||
};
|
||||
|
||||
// The underlying buffer is not owned by this class and is always owned by
|
||||
// others. The shape is not owned by this class and not mutable.
|
||||
class MutableBorrowingLiteral : public MutableLiteralBase {
|
||||
public:
|
||||
virtual ~MutableBorrowingLiteral();
|
||||
|
||||
MutableBorrowingLiteral() : MutableLiteralBase() {}
|
||||
|
||||
MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
|
||||
MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
|
||||
|
||||
// Implicit conversion constructors.
|
||||
MutableBorrowingLiteral(const MutableLiteralBase& literal);
|
||||
MutableBorrowingLiteral(MutableLiteralBase* literal);
|
||||
MutableBorrowingLiteral(MutableBorrowingLiteral literal,
|
||||
const ShapeIndex& view_root);
|
||||
MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
|
||||
|
||||
private:
|
||||
// Recursively copies the subtree from the `src_piece` at the given child
|
||||
// index to the `dest_piece`. For buffers only the pointers are copied, but
|
||||
// not the content.
|
||||
void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
|
||||
Piece* dest_piece);
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& out, const Literal& literal);
|
||||
|
||||
// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
|
||||
// literal buffers always owned by others.
|
||||
@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase {
|
||||
const Piece& root_piece() const override { return root_piece_; };
|
||||
Piece root_piece_;
|
||||
|
||||
// Shape of this literal. Stored as unique_ptr so such that the (default)
|
||||
// move construction of this class would be trivially correct: the pointer to
|
||||
// Shape root_piece_ stores will still point to the correct address.
|
||||
// Shape of this literal. Stored as unique_ptr such that the (default) move
|
||||
// construction of this class would be trivially correct: the pointer to Shape
|
||||
// root_piece_ stores will still point to the correct address.
|
||||
std::unique_ptr<Shape> shape_;
|
||||
};
|
||||
|
||||
@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
|
||||
tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
|
||||
const ShapeIndex& shape_index) {
|
||||
return piece(shape_index).data<NativeT>();
|
||||
}
|
||||
@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get(
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
|
||||
inline void MutableLiteralBase::Set(
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index,
|
||||
const ShapeIndex& shape_index, NativeT value) {
|
||||
return piece(shape_index).Set<NativeT>(multi_index, value);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
|
||||
NativeT value) {
|
||||
inline void MutableLiteralBase::Set(
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
|
||||
return root_piece().Set<NativeT>(multi_index, value);
|
||||
}
|
||||
|
||||
@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::AppendSparseElement(
|
||||
void MutableLiteralBase::AppendSparseElement(
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
|
||||
const ShapeIndex& shape_index) {
|
||||
Piece& p = piece(shape_index);
|
||||
@ -959,7 +993,8 @@ void LiteralBase::EachCell(
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
|
||||
inline void MutableLiteralBase::PopulateR1(
|
||||
tensorflow::gtl::ArraySlice<NativeT> values) {
|
||||
CHECK(ShapeUtil::IsArray(shape()));
|
||||
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
|
||||
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
|
||||
@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateR2(
|
||||
void MutableLiteralBase::PopulateR2(
|
||||
std::initializer_list<std::initializer_list<NativeT>> values) {
|
||||
CHECK(ShapeUtil::IsArray(shape()));
|
||||
CHECK_EQ(ShapeUtil::Rank(shape()), 2);
|
||||
@ -996,7 +1031,7 @@ void Literal::PopulateR2(
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateFromArray(const Array<NativeT>& values) {
|
||||
void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
|
||||
CHECK(ShapeUtil::IsArray(shape()));
|
||||
CHECK_EQ(shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<NativeT>());
|
||||
@ -1009,23 +1044,23 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) {
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
|
||||
void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
|
||||
PopulateFromArray(values);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
|
||||
void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
|
||||
PopulateFromArray(values);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
|
||||
void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
|
||||
PopulateFromArray(values);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateSparse(SparseIndexArray indices,
|
||||
tensorflow::gtl::ArraySlice<NativeT> values,
|
||||
void MutableLiteralBase::PopulateSparse(
|
||||
SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
|
||||
bool sort) {
|
||||
CHECK(LayoutUtil::IsSparseArray(shape()));
|
||||
int rank = ShapeUtil::Rank(shape());
|
||||
@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices,
|
||||
}
|
||||
|
||||
template <typename NativeT, typename FnType>
|
||||
Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
|
||||
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
|
||||
bool parallel) {
|
||||
const Shape& this_shape = shape();
|
||||
const int64 rank = ShapeUtil::Rank(this_shape);
|
||||
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
|
||||
@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
|
||||
return Status::OK();
|
||||
}
|
||||
template <typename NativeT, typename FnType>
|
||||
Status Literal::Populate(const FnType& generator) {
|
||||
Status MutableLiteralBase::Populate(const FnType& generator) {
|
||||
return PopulateInternal<NativeT>(generator, /*parallel=*/false);
|
||||
}
|
||||
|
||||
template <typename NativeT, typename FnType>
|
||||
Status Literal::PopulateParallel(const FnType& generator) {
|
||||
Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
|
||||
return PopulateInternal<NativeT>(generator, /*parallel=*/true);
|
||||
}
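// Hedged usage sketch (illustrative only): Populate invokes the generator once
// per element with that element's multi-dimensional index.
//
//   Literal literal(ShapeUtil::MakeShape(F32, {2, 3}));
//   TF_CHECK_OK(literal.Populate<float>(
//       [](tensorflow::gtl::ArraySlice<int64> index) {
//         return static_cast<float>(index[0] * 10 + index[1]);
//       }));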
|
||||
|
||||
template <typename NativeT>
|
||||
void Literal::PopulateWithValue(NativeT value) {
|
||||
void MutableLiteralBase::PopulateWithValue(NativeT value) {
|
||||
CHECK(ShapeUtil::IsArray(shape()));
|
||||
CHECK_EQ(shape().element_type(),
|
||||
primitive_util::NativeToPrimitiveType<NativeT>());
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::strings::StrCat;
|
||||
|
@ -570,7 +570,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//third_party/eigen3",
|
||||
@ -613,6 +613,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:ptr_util",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -1384,6 +1385,18 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop_analysis",
|
||||
srcs = ["while_loop_analysis.cc"],
|
||||
hdrs = ["while_loop_analysis.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_evaluator",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop_simplifier",
|
||||
srcs = ["while_loop_simplifier.cc"],
|
||||
@ -1391,8 +1404,8 @@ cc_library(
|
||||
deps = [
|
||||
":call_inliner",
|
||||
":hlo",
|
||||
":hlo_evaluator",
|
||||
":hlo_pass",
|
||||
":while_loop_analysis",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
|
@ -1803,6 +1803,12 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
|
||||
// TODO(b/112040122): Most of those optimizations can be done for multi-output
|
||||
// reduces.
|
||||
if (ShapeUtil::IsTuple(reduce->shape())) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto arg = reduce->mutable_operand(0);
|
||||
auto init_value = reduce->mutable_operand(1);
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
|
||||
|
@ -48,11 +48,6 @@ namespace xla {
|
||||
// computation.
|
||||
using ObjectFileData = std::vector<char>;
|
||||
|
||||
// Contains the buffer sizes information needed to allocate buffers to execute
|
||||
// an ahead-of-time computation. Entries which contain -1 designate a parameter
|
||||
// which should be skipped over during allocation.
|
||||
using BufferSizes = std::vector<int64>;
|
||||
|
||||
// Abstract superclass describing the result of an ahead-of-time compilation.
|
||||
class AotCompilationResult {
|
||||
public:
|
||||
|
@ -54,12 +54,24 @@ cc_library(
|
||||
alwayslink = True, # Contains per-platform transfer manager registration
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "buffer_info_util",
|
||||
srcs = ["buffer_info_util.cc"],
|
||||
hdrs = ["buffer_info_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_compiler",
|
||||
srcs = ["cpu_compiler.cc"],
|
||||
hdrs = ["cpu_compiler.h"],
|
||||
deps = [
|
||||
":compiler_functor",
|
||||
":buffer_info_util",
|
||||
":conv_canonicalization",
|
||||
":cpu_copy_insertion",
|
||||
":cpu_executable",
|
||||
@ -73,6 +85,7 @@ cc_library(
|
||||
":ir_emitter",
|
||||
":parallel_task_assignment",
|
||||
":simple_orc_jit",
|
||||
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
57
tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
Normal file
@ -0,0 +1,57 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"

namespace xla {
namespace cpu {

using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;

std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
    const BufferAssignment& buffer_assignment) {
  std::vector<BufferInfo> buffer_infos;
  for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
    if (allocation.is_thread_local()) {
      buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size()));
    } else if (allocation.is_constant()) {
      buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size()));
    } else if (allocation.is_entry_computation_parameter()) {
      buffer_infos.push_back(BufferInfo::MakeEntryParameter(
          /*size=*/allocation.size(),
          /*param_number=*/allocation.parameter_number()));
    } else {
      buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size()));
    }
  }
  return buffer_infos;
}

std::vector<int32> CreateArgIndexTableFromBufferInfos(
    tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
  std::vector<int32> result;
  for (int64 i = 0; i < buffer_infos.size(); i++) {
    if (buffer_infos[i].is_entry_parameter()) {
      if (buffer_infos[i].entry_parameter_number() >= result.size()) {
        result.resize(buffer_infos[i].entry_parameter_number() + 1);
      }
      result[buffer_infos[i].entry_parameter_number()] = i;
    }
  }
  return result;
}
} // namespace cpu
|
||||
} // namespace xla
|
42
tensorflow/compiler/xla/service/cpu/buffer_info_util.h
Normal file
@ -0,0 +1,42 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace xla {
namespace cpu {
// Creates and returns a list of BufferInfo instances containing relevant
// information from `buffer_assignment`.
std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
CreateBufferInfosFromBufferAssignment(
    const BufferAssignment& buffer_assignment);

// Creates and returns a table containing the mapping from entry computation
// parameters to buffer allocation indices.
//
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
    tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
        buffer_infos);
}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
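As a reading aid only (not part of the diff), here is a minimal sketch of how these two helpers compose, assuming a `BufferAssignment` called `assignment` like the one the diff passes to `CreateBufferInfosFromBufferAssignment` later in cpu_compiler.cc, and using XLA's int32 alias:

// Sketch, not part of the diff: wire the two helpers together.
std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos =
    xla::cpu::CreateBufferInfosFromBufferAssignment(*assignment);
std::vector<xla::int32> arg_index_table =
    xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
// arg_index_table[i] is the index into buffer_infos backing entry parameter i,
// which is exactly the V[i] mapping described in the header comment above.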
@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
@ -103,6 +104,7 @@ limitations under the License.

namespace xla {
namespace cpu {
using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
@ -120,11 +122,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
}

CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, BufferSizes buffer_sizes,
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64 result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_sizes_(std::move(buffer_sizes)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

@ -838,39 +840,14 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

    BufferSizes buffer_sizes;
    for (const BufferAllocation& allocation : assignment->Allocations()) {
      // Callers don't need to allocate anything for thread-local temporary
      // buffers. They are lowered to allocas.
      if (allocation.is_thread_local()) {
        buffer_sizes.push_back(-1);
        continue;
      }

      // Callers don't need to allocate anything for constant buffers. They are
      // lowered to globals.
      if (allocation.is_constant()) {
        buffer_sizes.push_back(-1);
        continue;
      }

      // Callers don't need to allocate anything for entry computation buffers,
      // but they do need to stash the pointer to the entry computation buffer
      // in the temp buffer table. See the comment on
      // XlaCompiledCpuFunction::StaticData::temp_sizes.
      if (allocation.is_entry_computation_parameter()) {
        buffer_sizes.push_back(-allocation.parameter_number() - 2);
        continue;
      }

      buffer_sizes.push_back(allocation.size());
    }
    std::vector<BufferInfo> buffer_infos =
        CreateBufferInfosFromBufferAssignment(*assignment);

    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(MakeUnique<CpuAotCompilationResult>(
        std::move(object_file_data), std::move(buffer_sizes),
        std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }
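For orientation, the BufferSizes convention deleted by this hunk packed three cases into a single int64 per allocation. A small decoder, shown purely as an illustration of the old encoding (not code from the tree, and using XLA's int64 alias), would read:

// Illustration of the legacy BufferSizes encoding replaced by BufferInfo.
//   -1        : thread-local or constant allocation; the caller allocates nothing.
//   <= -2     : entry computation parameter number (-encoded - 2).
//   otherwise : byte size of a temp/result buffer the caller must allocate.
void DescribeLegacyBufferSize(int64 encoded) {
  if (encoded == -1) {
    // Skipped during allocation (lowered to allocas or globals).
  } else if (encoded <= -2) {
    int64 param_number = -encoded - 2;  // which entry parameter to stash
    (void)param_number;
  } else {
    int64 size_in_bytes = encoded;      // caller-allocated temp or result buffer
    (void)size_in_bytes;
  }
}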
@ -19,6 +19,7 @@ limitations under the License.
#include <memory>

#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@ -78,7 +79,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
class CpuAotCompilationResult : public AotCompilationResult {
 public:
  CpuAotCompilationResult(
      ObjectFileData object_file_data, BufferSizes buffer_sizes,
      ObjectFileData object_file_data,
      std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos,
      int64 result_buffer_index,
      std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
  ~CpuAotCompilationResult();
@ -88,17 +90,20 @@ class CpuAotCompilationResult : public AotCompilationResult {
  }

  const ObjectFileData& object_file_data() const { return object_file_data_; }
  const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
  const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>&
  buffer_infos() const {
    return buffer_infos_;
  }
  int64 result_buffer_index() const { return result_buffer_index_; }

 private:
  // Contains the compiled computation: an object file.
  const ObjectFileData object_file_data_;

  // The list of buffer sizes which should be allocated in order to execute the
  // compiled computation. These buffers are used for temporary buffers used
  // ephemerally during computation as well as the output result.
  const BufferSizes buffer_sizes_;
  // A list of BufferInfo objects describing the buffers used by the XLA
  // computation.
  const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
      buffer_infos_;

  // Contains which buffer index into |buffer_sizes| was designated to the
  // result of the computation. This buffer should be passed into the output
@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,

Status CpuTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, const Shape& literal_shape,
    Literal* literal) {
    MutableBorrowingLiteral literal) {
  if (!ShapeUtil::IsTuple(literal_shape)) {
    int64 size = GetByteSizeRequirement(literal_shape);
    // Note: OSS build didn't like implicit conversion from
@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
    tensorflow::gtl::ArraySlice<int64> dimensions(
        tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
        literal_shape.dimensions().size());
    *literal = std::move(*LiteralUtil::CreateFromDimensions(
        literal_shape.element_type(), dimensions));
    TF_ASSIGN_OR_RETURN(Shape received_shape,
                        TransferArrayBufferFromOutfeed(
                            executor, literal->untyped_data(), size));
    TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
    TF_ASSIGN_OR_RETURN(
        Shape received_shape,
        TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
    TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
        << "Shape received from outfeed "
        << ShapeUtil::HumanString(received_shape)
        << " did not match the shape that was requested for outfeed: "
        << ShapeUtil::HumanString(literal_shape);
    TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
    *literal->mutable_shape_do_not_use() = received_shape;
    *literal.mutable_shape_do_not_use() = received_shape;
    return Status::OK();
  }

@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
        "Nested tuple outfeeds are not yet implemented on CPU.");
  }

  std::vector<std::unique_ptr<Literal>> elements;
  std::vector<std::pair<void*, int64>> buffer_data;
  for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(literal_shape, i);
    // Note: OSS build didn't like implicit conversion from
    // literal_shape.dimensions() to the array slice on 2017-07-10.
    tensorflow::gtl::ArraySlice<int64> dimensions(
        tensorflow::bit_cast<const int64*>(
            tuple_element_shape.dimensions().data()),
        tuple_element_shape.dimensions().size());
    auto empty = LiteralUtil::CreateFromDimensions(
        tuple_element_shape.element_type(), dimensions);
    int64 size = GetByteSizeRequirement(tuple_element_shape);
    buffer_data.push_back({empty->untyped_data(), size});
    elements.push_back(std::move(empty));
    buffer_data.push_back({literal.untyped_data({i}), size});
  }

  TF_ASSIGN_OR_RETURN(Shape received_shape,
@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
  TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
               GetByteSizeRequirement(received_shape));

  for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
    *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
  }
  *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
  TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
  TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
  return Status::OK();
}
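A hedged sketch of the new calling convention (the `transfer_manager`, `executor`, and `shape` names are assumptions, as is constructing a `MutableBorrowingLiteral` from a `Literal*`): the caller now owns the destination literal and lends it to the outfeed transfer instead of receiving a freshly allocated one.

// Sketch only: caller-provided destination for an outfeed transfer.
Literal destination(shape);  // caller-owned storage matching literal_shape
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
    executor, shape, MutableBorrowingLiteral(&destination)));
// destination is filled in place; no Literal is allocated inside the call.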
@ -18,6 +18,7 @@ limitations under the License.

#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager {
                                 const LiteralSlice& literal) override;
  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
                                    const Shape& literal_shape,
                                    Literal* literal) override;
                                    MutableBorrowingLiteral literal) override;

 private:
  Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
@ -30,47 +30,6 @@ limitations under the License.
namespace xla {
namespace cpu {

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  switch (op->opcode()) {
    case HloOpcode::kTanh: {
      PrimitiveType element_type = op->shape().element_type();
      bool cast_result_to_fp16 = false;
      string function_name;
      switch (element_type) {
        case F16:
          cast_result_to_fp16 = true;
          operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy());
          TF_FALLTHROUGH_INTENDED;
        case F32:
          function_name = "tanhf";
          break;
        case F64:
          function_name = "tanh";
          break;
        default:
          return Unimplemented("tanh");
      }
      // Create a function declaration.
      llvm::Function* function =
          llvm::cast<llvm::Function>(module_->getOrInsertFunction(
              llvm_ir::AsStringRef(function_name), operand_value->getType(),
              operand_value->getType()));
      function->setCallingConv(llvm::CallingConv::C);
      function->setDoesNotThrow();
      function->setDoesNotAccessMemory();
      // Create an instruction to call the function.
      llvm::Value* result = b_->CreateCall(function, operand_value);
      if (cast_result_to_fp16) {
        result = b_->CreateFPCast(result, b_->getHalfTy());
      }
      return result;
    }
    default:
      return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
  string function_name;
@ -106,6 +65,39 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
  return result;
}

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
    PrimitiveType prim_type, llvm::Value* value) const {
  bool cast_result_to_fp16 = false;
  string function_name;
  switch (prim_type) {
    case F16:
      cast_result_to_fp16 = true;
      value = b_->CreateFPCast(value, b_->getFloatTy());
      TF_FALLTHROUGH_INTENDED;
    case F32:
      function_name = "tanhf";
      break;
    case F64:
      function_name = "tanh";
      break;
    default:
      return Unimplemented("tanh");
  }
  // Create a function declaration.
  llvm::Function* function = llvm::cast<llvm::Function>(
      module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
                                   value->getType(), value->getType()));
  function->setCallingConv(llvm::CallingConv::C);
  function->setDoesNotThrow();
  function->setDoesNotAccessMemory();
  // Create an instruction to call the function.
  llvm::Value* result = b_->CreateCall(function, value);
  if (cast_result_to_fp16) {
    result = b_->CreateFPCast(result, b_->getHalfTy());
  }
  return result;
}

llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) const {
@ -39,10 +39,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
      const HloToElementGeneratorMap& operand_to_generator) const override;

 protected:
  StatusOr<llvm::Value*> EmitFloatUnaryOp(
      const HloInstruction* op, llvm::Value* operand_value) const override;
  StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                   llvm::Value* rhs) const override;
  StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                  llvm::Value* value) const override;

  IrEmitter* ir_emitter_;
};
@ -1756,6 +1756,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
}

Status IrEmitter::HandleReduce(HloInstruction* reduce) {
  // TODO(b/112040122): Support variadic reduce.
  if (!ShapeUtil::IsArray(reduce->shape())) {
    return Unimplemented("Variadic reduce is not supported on CPU");
  }
  auto arg = reduce->mutable_operand(0);
  auto init_value = reduce->mutable_operand(1);
  gtl::ArraySlice<int64> dimensions(reduce->dimensions());
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
#include "third_party/intel_mkl_ml/include/mkl_service.h"
@ -106,6 +106,7 @@ class DfsHloVisitorBase {
  virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
  virtual Status HandleFft(HloInstructionPtr fft) = 0;
  virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
  virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
  virtual Status HandleCompare(HloInstructionPtr hlo) {
    return HandleElementwiseBinary(hlo);
  }
@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase
  Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
    return DefaultAction(crs);
  }
  Status HandleAllToAll(HloInstructionPtr crs) override {
    return DefaultAction(crs);
  }
  Status HandleRng(HloInstructionPtr random) override {
    return DefaultAction(random);
  }
@ -431,6 +431,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kTanh:
      return EmitTanh(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
                                          {operand_value},
@ -1060,6 +1062,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
  return Unimplemented("atan2");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                    llvm::Value* value) const {
  return Unimplemented("tanh");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) const {
  if (hlo->operand(0)->shape().element_type() != F32) {
@ -1239,13 +1246,23 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
  // Convert raw integer to float in range [0, 1) if the element is a float.
  llvm::Value* elem_value = raw_value;
  if (elem_ir_ty->isFloatingPointTy()) {
    elem_value = b_->CreateUIToFP(elem_value, elem_ir_ty);
    unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits();
    CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64);
    // Perform the division using the float type with the same number of bits
    // as the raw value to avoid overflow.
    if (raw_value_size_in_bits == 32) {
      elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy());
      elem_value = b_->CreateFDiv(
          elem_value,
          llvm::ConstantFP::get(elem_ir_ty,
                                raw_value_size_in_bits == 64 ? 0x1p64 : 0x1p32));
          elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
    } else {
      elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy());
      elem_value = b_->CreateFDiv(
          elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
    }

    if (elem_ir_ty != elem_value->getType()) {
      elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty);
    }
  }

  // Convert the value for the requested distribution.
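The reason for dividing in a float type as wide as the raw integer: a narrow element type such as F16 cannot represent 2^32 or 2^64, so scaling inside that type overflows to infinity. A purely illustrative host-side analogue of the new sequence (widen, divide, then truncate), not code from the tree:

#include <cmath>
#include <cstdint>

// Host-side analogue of the fix: divide in a type at least as wide as the raw
// bits (here double for a 64-bit sample), then narrow, mirroring CreateFPTrunc.
float UniformFromRawBits(uint64_t raw_bits) {
  double wide = static_cast<double>(raw_bits) / std::exp2(64);  // in [0, 1)
  return static_cast<float>(wide);
}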
@ -1302,6 +1319,7 @@ int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) {
    case F16:
      return 4;
    case U64:
    case S64:
    case F64:
      return 2;
    default:
@ -122,6 +122,9 @@ class ElementalIrEmitter {
                                           llvm::Value* lhs,
                                           llvm::Value* rhs) const;

  virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                          llvm::Value* value) const;

  virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
                                                     llvm::Value* x) const;

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(

void GenericTransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
    MutableBorrowingLiteral literal, std::function<void(Status)> done) {
  Status status = stream->BlockHostUntilDone();
  if (!status.ok()) {
    return done(status);
  }
  done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));

  done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer,
                                         literal));
}

StatusOr<std::unique_ptr<Literal>>
GenericTransferManager::TransferLiteralFromDeviceInternal(
    se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
Status GenericTransferManager::TransferLiteralFromDeviceInternal(
    se::StreamExecutor* executor, const ShapedBuffer& device_buffer,
    MutableBorrowingLiteral literal) {
  VLOG(2) << "transferring literal from device ordinal "
          << executor->device_ordinal() << "; device buffer: " << device_buffer;
  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
                                device_buffer.on_host_shape()));

  std::unique_ptr<Literal> literal =
      Literal::CreateFromShape(device_buffer.on_host_shape());

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_host_shape(),
      [&](const Shape& subshape, const ShapeIndex& index) -> Status {
@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
            /*source=*/device_buffer.buffer(index),
            /*size=*/GetByteSizeRequirement(subshape),
            /*destination=*/
            literal->untyped_data(index)));
            literal.untyped_data(index)));
        }

        return Status::OK();
      }));
  return std::move(literal);
  return Status::OK();
}

Status GenericTransferManager::TransferLiteralToDeviceAsync(
@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(

Status GenericTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, const Shape& literal_shape,
    Literal* literal) {
    MutableBorrowingLiteral literal) {
  return Unimplemented("Generic transfer from Outfeed");
}

@ -19,7 +19,6 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager {

  se::Platform::Id PlatformId() const override;

  void TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override;
  void TransferLiteralFromDevice(se::Stream* stream,
                                 const ShapedBuffer& device_buffer,
                                 MutableBorrowingLiteral literal,
                                 std::function<void(Status)> done) override;

  Status TransferLiteralToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager {
                                 const LiteralSlice& literal) override;
  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
                                    const Shape& literal_shape,
                                    Literal* literal) override;
                                    MutableBorrowingLiteral literal) override;

  Status ResetDevices(
      tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager {
      const Shape& shape, se::DeviceMemoryBase* region) override;

 private:
  StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal(
      se::StreamExecutor* executor, const ShapedBuffer& device_buffer);
  Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor,
                                           const ShapedBuffer& device_buffer,
                                           MutableBorrowingLiteral literal);

  // The platform this transfer manager targets.
  const se::Platform::Id platform_id_;
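A short sketch of the revised asynchronous API (the `transfer_manager`, `stream`, and `shaped_buffer` names are assumptions, not part of the diff): the caller allocates the destination literal up front and the completion callback now carries only a Status.

// Sketch only: caller-owned destination plus a Status-only completion callback.
Literal result(shaped_buffer.on_host_shape());
transfer_manager->TransferLiteralFromDevice(
    stream, shaped_buffer, MutableBorrowingLiteral(&result),
    [](Status s) { TF_CHECK_OK(s); });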
@ -153,7 +153,6 @@ cc_library(
        ":ir_emission_utils",
        ":parallel_loop_emitter",
        ":partition_assignment",
        ":while_transformer",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
@ -166,6 +165,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:elemental_ir_emitter",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:name_uniquer",
        "//tensorflow/compiler/xla/service:while_loop_analysis",
        "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
        "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
        "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
@ -655,7 +655,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:transpose_folding",
        "//tensorflow/compiler/xla/service:tuple_simplifier",
        "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
        "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
        "//tensorflow/compiler/xla/service:while_loop_simplifier",
        "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
        "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
@ -788,32 +787,17 @@ tf_cc_test(
    ],
)

cc_library(
    name = "while_transformer",
    srcs = ["while_transformer.cc"],
    hdrs = ["while_transformer.h"],
    deps = [
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "while_transformer_test",
    srcs = ["while_transformer_test.cc"],
    deps = [
        ":instruction_fusion",
        ":while_transformer",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla/service:copy_insertion",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/service:while_loop_analysis",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"

namespace xla {
namespace gpu {
@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) {
      tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
}

// Acquires a process-global lock on the device pointed to by the given
// StreamExecutor.
//
// This is used to prevent other XLA instances from trying to autotune on this
// device while we're using it.
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  // se::Platform*s are global singletons guaranteed to live forever.
  static auto* mutexes =
      new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
                   tensorflow::mutex>();

  tensorflow::mutex_lock global_lock(mu);
  auto it = mutexes
                ->emplace(std::piecewise_construct,
                          std::make_tuple(stream_exec->platform(),
                                          stream_exec->device_ordinal()),
                          std::make_tuple())
                .first;
  return tensorflow::mutex_lock{it->second};
}

}  // anonymous namespace

// We could have caching here so that we don't redo this work for two identical
@ -155,6 +178,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
    const Shape& output_shape, const Window& window,
    const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
  // Don't run this function concurrently on the same GPU.
  //
  // This is a bit of a hack and doesn't protect us against arbitrary concurrent
  // use of a GPU, but it's sufficient to let us compile two HLO modules
  // concurrently and then run them sequentially.
  tensorflow::mutex_lock lock = LockGpu(stream_exec_);

  // Create a stream for us to do our work on.
  se::Stream stream{stream_exec_};
  stream.Init();
@ -272,27 +272,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
                              prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  switch (op->opcode()) {
    case HloOpcode::kTanh:
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
    PrimitiveType prim_type, llvm::Value* value) const {
  // If we don't care much about precision, emit a fast approximation of
  // tanh.
  if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
    // Upcast F16 to F32 if necessary.
    llvm::Type* type =
        input_type == F16 ? b_->getFloatTy() : operand_value->getType();
    llvm::Value* input = b_->CreateFPCast(operand_value, type);
    llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
    llvm::Value* input = b_->CreateFPCast(value, type);
    llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
    return b_->CreateFPCast(fast_tanh, operand_value->getType());
  }
  return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
                               output_type);
    default:
      return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
    return b_->CreateFPCast(fast_tanh, value->getType());
  }
  return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
}

llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
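llvm_ir::EmitFastTanh is defined elsewhere in the tree and its coefficients are not shown in this diff. Purely as an illustration of what a fast-math tanh of this kind looks like (not XLA's actual formula), a clamped rational approximation can be written as:

#include <algorithm>

// Illustrative clamped rational tanh approximation; not EmitFastTanh's formula.
// Clamping to [-3, 3] makes the rational form reach exactly +/-1 at the ends.
float FastTanhSketch(float x) {
  x = std::min(3.0f, std::max(-3.0f, x));
  const float x2 = x * x;
  return x * (27.0f + x2) / (27.0f + 9.0f * x2);
}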
@ -445,6 +436,8 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
        return b_->CreateLoad(accum_ptr);
      };
    case HloOpcode::kReduce:
      // TODO(b/112040122): This should be supported.
      CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
      return [=, &operand_to_generator](
                 const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);

@ -51,9 +51,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
      const HloToElementGeneratorMap& operand_to_generator) const override;

 protected:
  StatusOr<llvm::Value*> EmitFloatUnaryOp(
      const HloInstruction* op, llvm::Value* operand_value) const override;

  StatusOr<llvm::Value*> EmitFloatBinaryOp(
      const HloInstruction* op, llvm::Value* lhs_value,
      llvm::Value* rhs_value) const override;
@ -85,6 +82,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
  StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
                                   llvm::Value* rhs) const override;

  StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                  llvm::Value* value) const override;

  llvm::Value* EmitThreadId() const override;

 private:
@ -52,12 +52,12 @@ class GemmThunk : public Thunk {
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

  // Returns true if we'll perform autotuning if run on the given stream. If
  // so, we want the GPU to be quiescent during autotuning, so as not to
  // introduce noise in our results.
  bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override {
    return autotune_results_.count(
               stream->parent()->GetDeviceDescription().name()) != 0;
  bool WillAutotuneKernel(se::Stream* stream) override {
    // We will autotune this kernel if we don't already have a autotune result
    // for the stream device.
    return autotune_results_.find(
               stream->parent()->GetDeviceDescription().name()) ==
           autotune_results_.end();
  }

 private:
@ -75,6 +75,8 @@ class GemmThunk : public Thunk {
  // results. The map's value is the best algorithm we've found for this thunk
  // on this device, or an error if none of the algorithms worked and we should
  // use the regular gemm without an algorithm.
  //
  // TODO(b/112415150): Make this thread safe.
  std::unordered_map<string, StatusOr<se::blas::AlgorithmType>>
      autotune_results_;
};
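The comments above describe the caching contract; as a sketch only (RunGemmAutotuning is an assumed helper, and the thread-safety TODO is ignored), the lookup-or-autotune flow the map supports is:

// Sketch only: per-device cache of autotune outcomes, success or error alike.
const string device_name = stream->parent()->GetDeviceDescription().name();
auto it = autotune_results_.find(device_name);
if (it == autotune_results_.end()) {
  StatusOr<se::blas::AlgorithmType> best = RunGemmAutotuning(stream);  // assumed
  it = autotune_results_.emplace(device_name, std::move(best)).first;
}
// it->second is the algorithm to use, or the error meaning "fall back to plain gemm".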
@ -131,9 +131,10 @@ Status GpuExecutable::ExecuteThunks(
      stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
    }

    // If this thunk requests it, wait for all currently-executing thunks to
    // finish. This is useful e.g. if the thunk is about to perform autotuning.
    if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
    // If this thunk is about to autotune then wait for all currently executing
    // thunks to finish. This reduces noise and thus the probability of
    // choosing a suboptimal algorithm.
    if (thunk->WillAutotuneKernel(stream)) {
      TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
    }

Some files were not shown because too many files have changed in this diff.