Add basic gpu support in TF-TFRT integration. When compiling with --config=cuda, tfrt will automatically detect the gpu device.

Right now it only supports one gpu.

PiperOrigin-RevId: 309326059
Change-Id: Iab0f7bbed965d72c5b88cc3eef6d01b97b253d28
This commit is contained in:
Xiao Yu 2020-04-30 16:53:37 -07:00 committed by TensorFlower Gardener
parent 0024b55d63
commit 9c3f0435ad
3 changed files with 10 additions and 2 deletions

View File

@ -16,6 +16,7 @@ load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
package(
licenses = ["notice"], # Apache 2.0

View File

@ -694,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK();
return tensorflow::wrap(new tfrt::ContextInterface());
tfrt::SmallVector<std::string, 4> op_handler_chains;
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(
new tfrt::ContextInterface(op_handler_chains, device_attributes));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;

View File

@ -111,6 +111,8 @@ class AttrBuilder {
return *this;
}
size_t NumAttributes() const { return encoded_attrs_.size(); }
AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
AddAttrIfNotPresent(attr_name, value);
cached_cache_key_ = absl::nullopt;