Fix some internal tests.

PiperOrigin-RevId: 161093488
This commit is contained in:
A. Unique TensorFlower 2017-07-06 10:01:25 -07:00 committed by TensorFlower Gardener
parent 9cf4444655
commit 4de7361c43

View File

@ -20,7 +20,7 @@
namespace tensorflow { namespace tensorflow {
namespace tensorforest { namespace tensorforest {
typedef TTypes<float, 1>::ConstTensor SingleDimStorageType; typedef TTypes<float, 1>::UnalignedConstTensor SingleDimStorageType;
// Base class for classes that hold labels and weights. Mostly for testing // Base class for classes that hold labels and weights. Mostly for testing
// purposes, because it's inconvenient to construct nasty Eigen::things. // purposes, because it's inconvenient to construct nasty Eigen::things.
@ -54,9 +54,10 @@ class StoredInputTarget : public InputTarget {
class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> { class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
public: public:
TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets) TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets)
: StoredInputTarget(new SingleDimStorageType(target.tensor<float, 1>()), : StoredInputTarget(
new SingleDimStorageType(weight.tensor<float, 1>()), new SingleDimStorageType(target.unaligned_flat<float>()),
num_targets), new SingleDimStorageType(weight.unaligned_flat<float>()),
num_targets),
original_tensor_(target) {} original_tensor_(target) {}
int32 GetTargetAsClassIndex(int example_index, int32 GetTargetAsClassIndex(int example_index,