Allow kernel unit tests to run on GPU

PiperOrigin-RevId: 163705027
This commit is contained in:
A. Unique TensorFlower 2017-07-31 09:43:04 -07:00 committed by TensorFlower Gardener
parent 4ec29c5d95
commit f6f07b0275

View File

@ -83,6 +83,14 @@ class OpsTestBase : public ::testing::Test {
params_.reset(nullptr);
}
// Allow kernel unit tests to run on GPU
void SetDevice(const DeviceType& device_type,
std::unique_ptr<Device> device) {
CHECK(device_.get()) << "No device provided";
device_type_ = device_type;
device_ = std::move(device);
}
void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
// Clients can manipulate the underlying NodeDef via this accessor.