Fix some internal tests.
PiperOrigin-RevId: 161093488
This commit is contained in:
parent
9cf4444655
commit
4de7361c43
@ -20,7 +20,7 @@
|
||||
namespace tensorflow {
|
||||
namespace tensorforest {
|
||||
|
||||
typedef TTypes<float, 1>::ConstTensor SingleDimStorageType;
|
||||
typedef TTypes<float, 1>::UnalignedConstTensor SingleDimStorageType;
|
||||
|
||||
// Base class for classes that hold labels and weights. Mostly for testing
|
||||
// purposes, because it's inconvenient to construct nasty Eigen::things.
|
||||
@ -54,8 +54,9 @@ class StoredInputTarget : public InputTarget {
|
||||
class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
|
||||
public:
|
||||
TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets)
|
||||
: StoredInputTarget(new SingleDimStorageType(target.tensor<float, 1>()),
|
||||
new SingleDimStorageType(weight.tensor<float, 1>()),
|
||||
: StoredInputTarget(
|
||||
new SingleDimStorageType(target.unaligned_flat<float>()),
|
||||
new SingleDimStorageType(weight.unaligned_flat<float>()),
|
||||
num_targets),
|
||||
original_tensor_(target) {}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user