Use CapturableResource._resource_device instead of creating a new member for
the device. PiperOrigin-RevId: 256273685
This commit is contained in:
parent
05a4e4bf8c
commit
9189deb241
@ -828,18 +828,16 @@ class TRTEngineResource(tracking.TrackableResource):
|
|||||||
filename,
|
filename,
|
||||||
maximum_cached_engines,
|
maximum_cached_engines,
|
||||||
device="GPU"):
|
device="GPU"):
|
||||||
super(
|
super(TRTEngineResource, self).__init__(
|
||||||
TRTEngineResource,
|
device=device, deleter=TRTEngineResourceDeleter(resource_name, device))
|
||||||
self).__init__(deleter=TRTEngineResourceDeleter(resource_name, device))
|
|
||||||
self._resource_name = resource_name
|
self._resource_name = resource_name
|
||||||
# Track the serialized engine file in the SavedModel.
|
# Track the serialized engine file in the SavedModel.
|
||||||
self._filename = self._track_trackable(
|
self._filename = self._track_trackable(
|
||||||
tracking.TrackableAsset(filename), "_serialized_trt_engine_filename")
|
tracking.TrackableAsset(filename), "_serialized_trt_engine_filename")
|
||||||
self._maximum_cached_engines = maximum_cached_engines
|
self._maximum_cached_engines = maximum_cached_engines
|
||||||
self._device = device
|
|
||||||
|
|
||||||
def _create_resource(self):
|
def _create_resource(self):
|
||||||
return _get_resource_handle(self._resource_name, self._device)
|
return _get_resource_handle(self._resource_name, self._resource_device)
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
gen_trt_ops.populate_trt_engine_cache(
|
gen_trt_ops.populate_trt_engine_cache(
|
||||||
|
Loading…
Reference in New Issue
Block a user