Add basic GPU support to the TF-TFRT integration. When compiling with --config=cuda, TFRT will automatically detect the GPU device.
Right now only a single GPU is supported. PiperOrigin-RevId: 309326059 Change-Id: Iab0f7bbed965d72c5b88cc3eef6d01b97b253d28
This commit is contained in:
parent
0024b55d63
commit
9c3f0435ad
@ -16,6 +16,7 @@ load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
|
||||
@ -694,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
status->status = tensorflow::Status::OK();
|
||||
return tensorflow::wrap(new tfrt::ContextInterface());
|
||||
tfrt::SmallVector<std::string, 4> op_handler_chains;
|
||||
tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
|
||||
status->status = tfrt::ListOpHandlerChains(
|
||||
opts->session_options.options, &op_handler_chains, &device_attributes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::wrap(
|
||||
new tfrt::ContextInterface(op_handler_chains, device_attributes));
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
|
||||
@ -111,6 +111,8 @@ class AttrBuilder {
|
||||
return *this;
|
||||
}
|
||||
|
||||
size_t NumAttributes() const { return encoded_attrs_.size(); }
|
||||
|
||||
AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
|
||||
AddAttrIfNotPresent(attr_name, value);
|
||||
cached_cache_key_ = absl::nullopt;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user