Use CapturableResource._resource_device instead of creating a new member for

the device.

PiperOrigin-RevId: 256273685
This commit is contained in:
Guangda Lai 2019-07-02 17:36:26 -07:00 committed by TensorFlower Gardener
parent 05a4e4bf8c
commit 9189deb241

View File

@ -828,18 +828,16 @@ class TRTEngineResource(tracking.TrackableResource):
filename, filename,
maximum_cached_engines, maximum_cached_engines,
device="GPU"): device="GPU"):
super( super(TRTEngineResource, self).__init__(
TRTEngineResource, device=device, deleter=TRTEngineResourceDeleter(resource_name, device))
self).__init__(deleter=TRTEngineResourceDeleter(resource_name, device))
self._resource_name = resource_name self._resource_name = resource_name
# Track the serialized engine file in the SavedModel. # Track the serialized engine file in the SavedModel.
self._filename = self._track_trackable( self._filename = self._track_trackable(
tracking.TrackableAsset(filename), "_serialized_trt_engine_filename") tracking.TrackableAsset(filename), "_serialized_trt_engine_filename")
self._maximum_cached_engines = maximum_cached_engines self._maximum_cached_engines = maximum_cached_engines
self._device = device
def _create_resource(self): def _create_resource(self):
return _get_resource_handle(self._resource_name, self._device) return _get_resource_handle(self._resource_name, self._resource_device)
def _initialize(self): def _initialize(self):
gen_trt_ops.populate_trt_engine_cache( gen_trt_ops.populate_trt_engine_cache(