Update the TRT resource op names.

PiperOrigin-RevId: 261197551
This commit is contained in:
Guangda Lai 2019-08-01 14:26:11 -07:00 committed by TensorFlower Gardener
parent c38d739120
commit ebe2e50f36
11 changed files with 52 additions and 54 deletions

View File

@ -44,7 +44,7 @@ class GetCalibrationDataOp : public OpKernel {
// Get the resource.
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES_OK(context, context->resource_manager()->Lookup(
std::string(kCacheContainerName), resource_name,
std::string(kTfTrtContainerName), resource_name,
&resource));
core::ScopedUnref sc(resource);

View File

@ -618,7 +618,7 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
// Get engine cache.
return ctx->resource_manager()->LookupOrCreate(
std::string(kCacheContainerName), std::string(resource_name), cache_res,
std::string(kTfTrtContainerName), std::string(resource_name), cache_res,
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
if (calibration_mode_) {

View File

@ -126,8 +126,8 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
// Get the engine cache.
TRTEngineCacheResource* cache_resource = nullptr;
TF_ASSERT_OK(device_->resource_manager()->Lookup("TF-TRT-Engine-Cache",
"myop", &cache_resource));
TF_ASSERT_OK(
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
core::ScopedUnref sc(cache_resource);
// It should contain only one engine.

View File

@ -40,10 +40,9 @@ namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;
class CreateTRTEngineCacheHandle : public OpKernel {
class CreateTRTResourceHandle : public OpKernel {
public:
explicit CreateTRTEngineCacheHandle(OpKernelConstruction* ctx)
: OpKernel(ctx) {
explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
}
@ -60,7 +59,7 @@ class CreateTRTEngineCacheHandle : public OpKernel {
<< resource_name_ << " on device " << ctx->device()->name();
handle_.scalar<ResourceHandle>()() =
MakeResourceHandle<TRTEngineCacheResource>(
ctx, std::string(kCacheContainerName), resource_name_);
ctx, std::string(kTfTrtContainerName), resource_name_);
initialized_ = true;
}
}
@ -73,17 +72,17 @@ class CreateTRTEngineCacheHandle : public OpKernel {
mutex mutex_;
bool initialized_ GUARDED_BY(mutex_) = false;
TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCacheHandle);
TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle);
};
REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCacheHandle")
REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle")
.Device(DEVICE_GPU)
.HostMemory("engine_cache_handle"),
CreateTRTEngineCacheHandle);
.HostMemory("resource_handle"),
CreateTRTResourceHandle);
class InitializeTRTEngineOp : public OpKernel {
class InitializeTRTResource : public OpKernel {
public:
explicit InitializeTRTEngineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
}
@ -156,19 +155,18 @@ class InitializeTRTEngineOp : public OpKernel {
// Maximum number of cached engines
int max_cached_engines_;
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTEngineOp);
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource);
};
REGISTER_KERNEL_BUILDER(Name("InitializeTRTEngineOp")
REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource")
.Device(DEVICE_GPU)
.HostMemory("engine_cache_handle"),
InitializeTRTEngineOp);
.HostMemory("resource_handle"),
InitializeTRTResource);
class SerializeTRTEngineOp : public OpKernel {
class SerializeTRTResource : public OpKernel {
public:
explicit SerializeTRTEngineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
&delete_cache_after_dump_));
explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_));
}
void Compute(OpKernelContext* ctx) override {
@ -179,7 +177,7 @@ class SerializeTRTEngineOp : public OpKernel {
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES_OK(
ctx, ctx->resource_manager()->Lookup(std::string(kCacheContainerName),
ctx, ctx->resource_manager()->Lookup(std::string(kTfTrtContainerName),
resource_name, &resource));
core::ScopedUnref unref_me(resource);
@ -212,23 +210,23 @@ class SerializeTRTEngineOp : public OpKernel {
<< " TRT engines for op " << resource_name << " on device "
<< ctx->device()->name() << " to file " << filename;
if (delete_cache_after_dump_) {
if (delete_resource_) {
VLOG(1) << "Destroying TRT engine cache resource for op " << resource_name
<< " on device " << ctx->device()->name();
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->Delete<TRTEngineCacheResource>(
std::string(kCacheContainerName), resource_name));
std::string(kTfTrtContainerName), resource_name));
}
}
private:
bool delete_cache_after_dump_ = false;
bool delete_resource_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTEngineOp);
TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource);
};
REGISTER_KERNEL_BUILDER(Name("SerializeTRTEngineOp").Device(DEVICE_GPU),
SerializeTRTEngineOp);
REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU),
SerializeTRTResource);
} // namespace tensorrt
} // namespace tensorflow

View File

@ -92,10 +92,10 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
SetDevice(DEVICE_GPU, std::move(device));
// Create the resource handle.
const string container(kCacheContainerName);
const string container(kTfTrtContainerName);
const string resource_name = "myresource";
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCacheHandle")
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTResourceHandle")
.Attr("resource_name", resource_name)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
@ -107,7 +107,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
EXPECT_TRUE(
errors::IsNotFound(rm->Lookup(container, resource_name, &resource)));
// Create the resouce using an empty file with InitializeTRTEngineOp.
// Create the resource using an empty file with InitializeTRTResource.
Reset();
Env* env = Env::Default();
const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
@ -115,7 +115,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
std::unique_ptr<WritableFile> file;
TF_ASSERT_OK(env->NewWritableFile(filename, &file));
}
TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTEngineOp")
TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
.Input(FakeInput(DT_RESOURCE))
.Input(FakeInput(DT_STRING))
.Attr("max_cached_engines_count", 1)
@ -136,10 +136,10 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
resource->Unref();
// Serialize the engine using SerializeTRTEngineOp op.
// Serialize the engine using SerializeTRTResource op.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTEngineOp")
.Attr("delete_cache_after_dump", true)
TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTResource")
.Attr("delete_resource", true)
.Input(FakeInput(DT_STRING))
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
@ -175,7 +175,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {
// Recreate the cache resource.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTEngineOp")
TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
.Input(FakeInput(DT_RESOURCE))
.Input(FakeInput(DT_STRING))
.Attr("max_cached_engines_count", 1)

View File

@ -24,21 +24,21 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("CreateTRTEngineCacheHandle")
REGISTER_OP("CreateTRTResourceHandle")
.Attr("resource_name: string")
.Output("engine_cache_handle: resource")
.Output("resource_handle: resource")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("InitializeTRTEngineOp")
REGISTER_OP("InitializeTRTResource")
.Attr("max_cached_engines_count: int = 1")
.Input("engine_cache_handle: resource")
.Input("resource_handle: resource")
.Input("filename: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("SerializeTRTEngineOp")
.Attr("delete_cache_after_dump: bool = false")
REGISTER_OP("SerializeTRTResource")
.Attr("delete_resource: bool = false")
.Input("resource_name: string")
.Input("filename: string")
.SetIsStateful()

View File

@ -30,7 +30,7 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
const absl::string_view kCacheContainerName = "TF-TRT-Engine-Cache";
const absl::string_view kTfTrtContainerName = "TF-TRT";
Logger& TRTEngineCacheResource::GetLogger() {
static Logger* logger = new Logger();

View File

@ -170,7 +170,7 @@ class CalibrationContext {
std::unique_ptr<std::thread> thr_;
};
ABSL_CONST_INIT extern const absl::string_view kCacheContainerName;
ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;
class TRTEngineCacheResource : public ResourceBase {
public:

View File

@ -42,9 +42,9 @@ const std::unordered_set<std::string>* GetExcludedOps() {
"QuantizedMatMulWithBiasAndReluAndRequantize",
#endif // INTEL_MKL
#ifdef GOOGLE_TENSORRT
"CreateTRTEngineCacheHandle",
"InitializeTRTEngineOp",
"SerializeTRTEngineOp",
"CreateTRTResourceHandle",
"InitializeTRTResource",
"SerializeTRTResource",
"GetCalibrationDataOp",
"TRTEngineOp",
#endif // GOOGLE_TENSORRT

View File

@ -717,7 +717,7 @@ class TrtGraphConverter(object):
def _get_resource_handle(name, device):
with ops.device(device):
return gen_trt_ops.create_trt_engine_cache_handle(resource_name=name)
return gen_trt_ops.create_trt_resource_handle(resource_name=name)
class TRTEngineResourceDeleter(tracking.CapturableResourceDeleter):
@ -748,14 +748,14 @@ class TRTEngineResource(tracking.TrackableResource):
self._resource_name = resource_name
# Track the serialized engine file in the SavedModel.
self._filename = self._track_trackable(
tracking.TrackableAsset(filename), "_serialized_trt_engine_filename")
tracking.TrackableAsset(filename), "_serialized_trt_resource_filename")
self._maximum_cached_engines = maximum_cached_engines
def _create_resource(self):
return _get_resource_handle(self._resource_name, self._resource_device)
def _initialize(self):
gen_trt_ops.initialize_trt_engine_op(
gen_trt_ops.initialize_trt_resource(
self.resource_handle,
self._filename,
max_cached_engines_count=self._maximum_cached_engines)
@ -930,10 +930,10 @@ class TrtGraphConverterV2(object):
filename = os.path.join(engine_asset_dir,
"trt-serialized-engine." + canonical_engine_name)
try:
gen_trt_ops.serialize_trt_engine_op(
gen_trt_ops.serialize_trt_resource(
resource_name=canonical_engine_name,
filename=filename,
delete_cache_after_dump=True)
delete_resource=True)
except errors.NotFoundError:
# If user haven't run the function to populate the engine, it's fine,
# and we don't need to track any serialized TRT engines.

View File

@ -413,7 +413,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
def _destroy_cache():
with ops.device("GPU:0"):
handle = gen_trt_ops.create_trt_engine_cache_handle(
handle = gen_trt_ops.create_trt_resource_handle(
resource_name="TRTEngineOp_0")
gen_resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=False)