Allow kernel unit tests to run on GPU
PiperOrigin-RevId: 163705027
This commit is contained in:
parent
4ec29c5d95
commit
f6f07b0275
@ -83,6 +83,14 @@ class OpsTestBase : public ::testing::Test {
|
||||
params_.reset(nullptr);
|
||||
}
|
||||
|
||||
// Allow kernel unit tests to run on GPU
|
||||
void SetDevice(const DeviceType& device_type,
|
||||
std::unique_ptr<Device> device) {
|
||||
CHECK(device_.get()) << "No device provided";
|
||||
device_type_ = device_type;
|
||||
device_ = std::move(device);
|
||||
}
|
||||
|
||||
void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
|
||||
|
||||
// Clients can manipulate the underlying NodeDef via this accessor.
|
||||
|
Loading…
Reference in New Issue
Block a user