Update the TRT resource op names.
PiperOrigin-RevId: 261197551
Commit: ebe2e50f36
Parent: c38d739120
Changed files:
  tensorflow/compiler/tf2tensorrt/kernels/
    get_calibration_data_op.cc
    trt_engine_op.cc
    trt_engine_op_test.cc
    trt_engine_resource_ops.cc
    trt_engine_resource_ops_test.cc
  tensorflow/compiler/tf2tensorrt/ops/
  tensorflow/compiler/tf2tensorrt/utils/
  tensorflow/core/api_def/
  tensorflow/python/compiler/tensorrt/
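For orientation, the renames are: CreateTRTEngineCacheHandle -> CreateTRTResourceHandle, InitializeTRTEngineOp -> InitializeTRTResource, and SerializeTRTEngineOp -> SerializeTRTResource (its delete_cache_after_dump attribute becomes delete_resource); the resource container name also changes from "TF-TRT-Engine-Cache" to "TF-TRT". The minimal sketch below mirrors the generated Python endpoints exactly as they are called in this diff; the import path of gen_trt_ops is an assumption (the diff only refers to the module by that name), and the helper functions are hypothetical.

    # Minimal sketch of the renamed TF-TRT resource ops, mirroring the calls in
    # this diff. Assumption: the generated wrappers are importable as
    # tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops.
    from tensorflow.compiler.tf2tensorrt.ops import gen_trt_ops
    from tensorflow.python.framework import ops

    def restore_engine_cache(resource_name, filename, device="GPU:0"):
      """Create the TF-TRT resource handle and populate it from a serialized file."""
      with ops.device(device):
        # Old names: create_trt_engine_cache_handle / initialize_trt_engine_op.
        handle = gen_trt_ops.create_trt_resource_handle(resource_name=resource_name)
        gen_trt_ops.initialize_trt_resource(handle, filename,
                                            max_cached_engines_count=1)
      return handle

    def dump_engine_cache(resource_name, filename):
      """Serialize the cached engines to a file, then delete the resource."""
      # Old name: serialize_trt_engine_op(..., delete_cache_after_dump=True).
      gen_trt_ops.serialize_trt_resource(resource_name=resource_name,
                                         filename=filename,
                                         delete_resource=True)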
@@ -44,7 +44,7 @@ class GetCalibrationDataOp : public OpKernel {
     // Get the resource.
     TRTEngineCacheResource* resource = nullptr;
     OP_REQUIRES_OK(context, context->resource_manager()->Lookup(
-                                std::string(kCacheContainerName), resource_name,
+                                std::string(kTfTrtContainerName), resource_name,
                                 &resource));
     core::ScopedUnref sc(resource);

@@ -618,7 +618,7 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,

   // Get engine cache.
   return ctx->resource_manager()->LookupOrCreate(
-      std::string(kCacheContainerName), std::string(resource_name), cache_res,
+      std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
       {[this, ctx](TRTEngineCacheResource** cr) -> Status {
         *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
         if (calibration_mode_) {
@@ -126,8 +126,8 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {

   // Get the engine cache.
   TRTEngineCacheResource* cache_resource = nullptr;
-  TF_ASSERT_OK(device_->resource_manager()->Lookup("TF-TRT-Engine-Cache",
-                                                   "myop", &cache_resource));
+  TF_ASSERT_OK(
+      device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
   core::ScopedUnref sc(cache_resource);

   // It should contain only one engine.
@@ -40,10 +40,9 @@ namespace tensorflow {
 namespace tensorrt {
 using ::nvinfer1::IRuntime;

-class CreateTRTEngineCacheHandle : public OpKernel {
+class CreateTRTResourceHandle : public OpKernel {
  public:
-  explicit CreateTRTEngineCacheHandle(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {
+  explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
   }

@@ -60,7 +59,7 @@ class CreateTRTEngineCacheHandle : public OpKernel {
               << resource_name_ << " on device " << ctx->device()->name();
       handle_.scalar<ResourceHandle>()() =
           MakeResourceHandle<TRTEngineCacheResource>(
-              ctx, std::string(kCacheContainerName), resource_name_);
+              ctx, std::string(kTfTrtContainerName), resource_name_);
       initialized_ = true;
     }
   }
@@ -73,17 +72,17 @@ class CreateTRTEngineCacheHandle : public OpKernel {
   mutex mutex_;
   bool initialized_ GUARDED_BY(mutex_) = false;

-  TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCacheHandle);
+  TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle);
 };

-REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCacheHandle")
+REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle")
                             .Device(DEVICE_GPU)
-                            .HostMemory("engine_cache_handle"),
-                        CreateTRTEngineCacheHandle);
+                            .HostMemory("resource_handle"),
+                        CreateTRTResourceHandle);

-class InitializeTRTEngineOp : public OpKernel {
+class InitializeTRTResource : public OpKernel {
  public:
-  explicit InitializeTRTEngineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+  explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(
         ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
   }
@@ -156,19 +155,18 @@ class InitializeTRTEngineOp : public OpKernel {
   // Maximum number of cached engines
   int max_cached_engines_;

-  TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTEngineOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource);
 };

-REGISTER_KERNEL_BUILDER(Name("InitializeTRTEngineOp")
+REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource")
                             .Device(DEVICE_GPU)
-                            .HostMemory("engine_cache_handle"),
-                        InitializeTRTEngineOp);
+                            .HostMemory("resource_handle"),
+                        InitializeTRTResource);

-class SerializeTRTEngineOp : public OpKernel {
+class SerializeTRTResource : public OpKernel {
  public:
-  explicit SerializeTRTEngineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
-                                     &delete_cache_after_dump_));
+  explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_));
   }

   void Compute(OpKernelContext* ctx) override {
@@ -179,7 +177,7 @@ class SerializeTRTEngineOp : public OpKernel {

     TRTEngineCacheResource* resource = nullptr;
     OP_REQUIRES_OK(
-        ctx, ctx->resource_manager()->Lookup(std::string(kCacheContainerName),
+        ctx, ctx->resource_manager()->Lookup(std::string(kTfTrtContainerName),
                                              resource_name, &resource));
     core::ScopedUnref unref_me(resource);

@@ -212,23 +210,23 @@ class SerializeTRTEngineOp : public OpKernel {
             << " TRT engines for op " << resource_name << " on device "
             << ctx->device()->name() << " to file " << filename;

-    if (delete_cache_after_dump_) {
+    if (delete_resource_) {
       VLOG(1) << "Destroying TRT engine cache resource for op " << resource_name
               << " on device " << ctx->device()->name();
       OP_REQUIRES_OK(ctx,
                      ctx->resource_manager()->Delete<TRTEngineCacheResource>(
-                         std::string(kCacheContainerName), resource_name));
+                         std::string(kTfTrtContainerName), resource_name));
     }
   }

  private:
-  bool delete_cache_after_dump_ = false;
+  bool delete_resource_ = false;

-  TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTEngineOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource);
 };

-REGISTER_KERNEL_BUILDER(Name("SerializeTRTEngineOp").Device(DEVICE_GPU),
-                        SerializeTRTEngineOp);
+REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU),
+                        SerializeTRTResource);

 }  // namespace tensorrt
 }  // namespace tensorflow
@@ -92,10 +92,10 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
   SetDevice(DEVICE_GPU, std::move(device));

   // Create the resource handle.
-  const string container(kCacheContainerName);
+  const string container(kTfTrtContainerName);
   const string resource_name = "myresource";
   Reset();
-  TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCacheHandle")
+  TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTResourceHandle")
                    .Attr("resource_name", resource_name)
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
@@ -107,7 +107,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
   EXPECT_TRUE(
       errors::IsNotFound(rm->Lookup(container, resource_name, &resource)));

-  // Create the resouce using an empty file with InitializeTRTEngineOp.
+  // Create the resouce using an empty file with InitializeTRTResource.
   Reset();
   Env* env = Env::Default();
   const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
@@ -115,7 +115,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
     std::unique_ptr<WritableFile> file;
     TF_ASSERT_OK(env->NewWritableFile(filename, &file));
   }
-  TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTEngineOp")
+  TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
                    .Input(FakeInput(DT_RESOURCE))
                    .Input(FakeInput(DT_STRING))
                    .Attr("max_cached_engines_count", 1)
@@ -136,10 +136,10 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
       absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
   resource->Unref();

-  // Serialize the engine using SerializeTRTEngineOp op.
+  // Serialize the engine using SerializeTRTResource op.
   Reset();
-  TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTEngineOp")
-                   .Attr("delete_cache_after_dump", true)
+  TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTResource")
+                   .Attr("delete_resource", true)
                    .Input(FakeInput(DT_STRING))
                    .Input(FakeInput(DT_STRING))
                    .Finalize(node_def()));
@@ -175,7 +175,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {

   // Recreate the cache resource.
   Reset();
-  TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTEngineOp")
+  TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
                    .Input(FakeInput(DT_RESOURCE))
                    .Input(FakeInput(DT_STRING))
                    .Attr("max_cached_engines_count", 1)
@@ -24,21 +24,21 @@ limitations under the License.

 namespace tensorflow {

-REGISTER_OP("CreateTRTEngineCacheHandle")
+REGISTER_OP("CreateTRTResourceHandle")
     .Attr("resource_name: string")
-    .Output("engine_cache_handle: resource")
+    .Output("resource_handle: resource")
     .SetIsStateful()
     .SetShapeFn(shape_inference::ScalarShape);

-REGISTER_OP("InitializeTRTEngineOp")
+REGISTER_OP("InitializeTRTResource")
     .Attr("max_cached_engines_count: int = 1")
-    .Input("engine_cache_handle: resource")
+    .Input("resource_handle: resource")
     .Input("filename: string")
     .SetIsStateful()
     .SetShapeFn(shape_inference::NoOutputs);

-REGISTER_OP("SerializeTRTEngineOp")
-    .Attr("delete_cache_after_dump: bool = false")
+REGISTER_OP("SerializeTRTResource")
+    .Attr("delete_resource: bool = false")
     .Input("resource_name: string")
     .Input("filename: string")
     .SetIsStateful()
@@ -30,7 +30,7 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {

-const absl::string_view kCacheContainerName = "TF-TRT-Engine-Cache";
+const absl::string_view kTfTrtContainerName = "TF-TRT";

 Logger& TRTEngineCacheResource::GetLogger() {
   static Logger* logger = new Logger();
@@ -170,7 +170,7 @@ class CalibrationContext {
   std::unique_ptr<std::thread> thr_;
 };

-ABSL_CONST_INIT extern const absl::string_view kCacheContainerName;
+ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;

 class TRTEngineCacheResource : public ResourceBase {
  public:
@@ -42,9 +42,9 @@ const std::unordered_set<std::string>* GetExcludedOps() {
         "QuantizedMatMulWithBiasAndReluAndRequantize",
 #endif  // INTEL_MKL
 #ifdef GOOGLE_TENSORRT
-        "CreateTRTEngineCacheHandle",
-        "InitializeTRTEngineOp",
-        "SerializeTRTEngineOp",
+        "CreateTRTResourceHandle",
+        "InitializeTRTResource",
+        "SerializeTRTResource",
         "GetCalibrationDataOp",
         "TRTEngineOp",
 #endif  // GOOGLE_TENSORRT
@@ -717,7 +717,7 @@ class TrtGraphConverter(object):

 def _get_resource_handle(name, device):
   with ops.device(device):
-    return gen_trt_ops.create_trt_engine_cache_handle(resource_name=name)
+    return gen_trt_ops.create_trt_resource_handle(resource_name=name)


 class TRTEngineResourceDeleter(tracking.CapturableResourceDeleter):
@@ -748,14 +748,14 @@ class TRTEngineResource(tracking.TrackableResource):
     self._resource_name = resource_name
     # Track the serialized engine file in the SavedModel.
     self._filename = self._track_trackable(
-        tracking.TrackableAsset(filename), "_serialized_trt_engine_filename")
+        tracking.TrackableAsset(filename), "_serialized_trt_resource_filename")
     self._maximum_cached_engines = maximum_cached_engines

   def _create_resource(self):
     return _get_resource_handle(self._resource_name, self._resource_device)

   def _initialize(self):
-    gen_trt_ops.initialize_trt_engine_op(
+    gen_trt_ops.initialize_trt_resource(
         self.resource_handle,
         self._filename,
         max_cached_engines_count=self._maximum_cached_engines)
@@ -930,10 +930,10 @@ class TrtGraphConverterV2(object):
       filename = os.path.join(engine_asset_dir,
                               "trt-serialized-engine." + canonical_engine_name)
       try:
-        gen_trt_ops.serialize_trt_engine_op(
+        gen_trt_ops.serialize_trt_resource(
             resource_name=canonical_engine_name,
             filename=filename,
-            delete_cache_after_dump=True)
+            delete_resource=True)
       except errors.NotFoundError:
         # If user haven't run the function to populate the engine, it's fine,
         # and we don't need to track any serialized TRT engines.
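As a usage note, the serialization path above runs when a converted SavedModel is written out, which is where the renamed SerializeTRTResource op is invoked. A hedged sketch of that high-level flow, assuming the TrtGraphConverterV2 API in trt_convert.py at roughly this revision; the directory paths are illustrative:

    # Hypothetical end-to-end flow; SerializeTRTResource runs inside save().
    from tensorflow.python.compiler.tensorrt import trt_convert

    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir="/tmp/my_saved_model")  # assumed input path
    converter.convert()
    # save() serializes the cached engines via the renamed resource op.
    converter.save("/tmp/my_trt_saved_model")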
@@ -413,7 +413,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):

     def _destroy_cache():
       with ops.device("GPU:0"):
-        handle = gen_trt_ops.create_trt_engine_cache_handle(
+        handle = gen_trt_ops.create_trt_resource_handle(
             resource_name="TRTEngineOp_0")
         gen_resource_variable_ops.destroy_resource_op(
             handle, ignore_lookup_error=False)