Add basic gpu support in TF-TFRT integration. When compiling with --config=cuda, tfrt will automatically detect the gpu device.
Right now it only supports one gpu. PiperOrigin-RevId: 309326059 Change-Id: Iab0f7bbed965d72c5b88cc3eef6d01b97b253d28
This commit is contained in:
parent
0024b55d63
commit
9c3f0435ad
@ -16,6 +16,7 @@ load(
|
|||||||
"//tensorflow/core/platform:build_config_root.bzl",
|
"//tensorflow/core/platform:build_config_root.bzl",
|
||||||
"tf_cuda_tests_tags",
|
"tf_cuda_tests_tags",
|
||||||
)
|
)
|
||||||
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
|||||||
@ -694,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
|||||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||||
if (opts->use_tfrt) {
|
if (opts->use_tfrt) {
|
||||||
#ifdef PLATFORM_GOOGLE
|
#ifdef PLATFORM_GOOGLE
|
||||||
status->status = tensorflow::Status::OK();
|
tfrt::SmallVector<std::string, 4> op_handler_chains;
|
||||||
return tensorflow::wrap(new tfrt::ContextInterface());
|
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
|
||||||
|
status->status = tfrt::ListOpHandlerChains(
|
||||||
|
opts->session_options.options, &op_handler_chains, &device_attributes);
|
||||||
|
if (!status->status.ok()) return nullptr;
|
||||||
|
return tensorflow::wrap(
|
||||||
|
new tfrt::ContextInterface(op_handler_chains, device_attributes));
|
||||||
#else
|
#else
|
||||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
@ -111,6 +111,8 @@ class AttrBuilder {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t NumAttributes() const { return encoded_attrs_.size(); }
|
||||||
|
|
||||||
AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
|
AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
|
||||||
AddAttrIfNotPresent(attr_name, value);
|
AddAttrIfNotPresent(attr_name, value);
|
||||||
cached_cache_key_ = absl::nullopt;
|
cached_cache_key_ = absl::nullopt;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user