Fix some internal tests.
PiperOrigin-RevId: 161093488
This commit is contained in:
parent
9cf4444655
commit
4de7361c43
@ -20,7 +20,7 @@
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorforest {
|
namespace tensorforest {
|
||||||
|
|
||||||
typedef TTypes<float, 1>::ConstTensor SingleDimStorageType;
|
typedef TTypes<float, 1>::UnalignedConstTensor SingleDimStorageType;
|
||||||
|
|
||||||
// Base class for classes that hold labels and weights. Mostly for testing
|
// Base class for classes that hold labels and weights. Mostly for testing
|
||||||
// purposes, because it's inconvenient to construct nasty Eigen::things.
|
// purposes, because it's inconvenient to construct nasty Eigen::things.
|
||||||
@ -54,9 +54,10 @@ class StoredInputTarget : public InputTarget {
|
|||||||
class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
|
class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
|
||||||
public:
|
public:
|
||||||
TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets)
|
TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets)
|
||||||
: StoredInputTarget(new SingleDimStorageType(target.tensor<float, 1>()),
|
: StoredInputTarget(
|
||||||
new SingleDimStorageType(weight.tensor<float, 1>()),
|
new SingleDimStorageType(target.unaligned_flat<float>()),
|
||||||
num_targets),
|
new SingleDimStorageType(weight.unaligned_flat<float>()),
|
||||||
|
num_targets),
|
||||||
original_tensor_(target) {}
|
original_tensor_(target) {}
|
||||||
|
|
||||||
int32 GetTargetAsClassIndex(int example_index,
|
int32 GetTargetAsClassIndex(int example_index,
|
||||||
|
Loading…
Reference in New Issue
Block a user