Allow kernel unit tests to run on GPU
PiperOrigin-RevId: 163705027
This commit is contained in:
parent
4ec29c5d95
commit
f6f07b0275
@ -83,6 +83,14 @@ class OpsTestBase : public ::testing::Test {
|
|||||||
params_.reset(nullptr);
|
params_.reset(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allow kernel unit tests to run on GPU
|
||||||
|
void SetDevice(const DeviceType& device_type,
|
||||||
|
std::unique_ptr<Device> device) {
|
||||||
|
CHECK(device_.get()) << "No device provided";
|
||||||
|
device_type_ = device_type;
|
||||||
|
device_ = std::move(device);
|
||||||
|
}
|
||||||
|
|
||||||
void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
|
void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
|
||||||
|
|
||||||
// Clients can manipulate the underlying NodeDef via this accessor.
|
// Clients can manipulate the underlying NodeDef via this accessor.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user