Commit 9ba87e83 authored by fzyang

Fix for feature sampling error

parent 35f0dbe5
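The change returns the original feature indices from `sampleFeatures` alongside the sampled values and threads them through `iTree`, so that each `IFInternalNode` stores `featureIdxArr(attrIndex)` (an index into the full feature vector) rather than the column position within the per-tree feature sample. A minimal, standalone sketch of the indexing issue this addresses (illustrative names and values, not the library code):

```scala
import scala.util.Random

object FeatureIndexMappingSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(42L)
    val point = Array(10.0, 20.0, 30.0, 40.0)                      // one instance, 4 original features
    val featureIdxArr = rng.shuffle(point.indices.toList).take(2)  // original indices of the 2 sampled columns
    val sampledPoint = featureIdxArr.map(i => point(i)).toArray    // values in sampled-column order

    val attrIndex = 0 // suppose a tree node split on sampled column 0
    // Buggy behaviour: applying the sampled-space index to a full-width vector at prediction time.
    val buggyValue = point(attrIndex)
    // Fixed behaviour: map back through featureIdxArr before storing/using the index.
    val fixedValue = point(featureIdxArr(attrIndex))
    println(s"sampled = ${sampledPoint.mkString(",")}, buggy = $buggyValue, fixed = $fixedValue")
  }
}
```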
@@ -498,7 +498,7 @@ class IForest (
     val _trees = rddPerTree.map {
       case (treeId: Int, points: Array[Vector]) =>
         // sample features
-        val trainData = sampleFeatures(points, $(maxFeatures))
+        val (trainData, featureIdxArr) = sampleFeatures(points, $(maxFeatures))
         // calculate actual maxDepth to limit tree height
         val longestPath = Math.ceil(Math.log(Math.max(2, points.length)) / Math.log(2)).toInt
@@ -515,7 +515,7 @@ class IForest (
         // last position's value indicates constant feature offset index
         constantFeatures(numFeatures) = 0
         // build a tree
-        iTree(trainData, 0, possibleMaxDepth, constantFeatures)
+        iTree(trainData, 0, possibleMaxDepth, constantFeatures, featureIdxArr)
     }.collect()
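For reference, the per-tree height limit computed in the first hunk is the usual isolation-forest bound of ceil(log2(n)) on the sampled point count; restated as a tiny helper (illustrative, not the library's code):

```scala
// Height limit used when growing each tree: ceil(log2(max(2, numPoints))).
def heightLimit(numPoints: Int): Int =
  Math.ceil(Math.log(Math.max(2, numPoints)) / Math.log(2)).toInt

// heightLimit(256) == 8, heightLimit(100) == 7, heightLimit(1) == 1
```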
@@ -532,22 +532,26 @@ class IForest (
    * Sample features to train a tree.
    * @param data Input data to train a tree, each element is an instance.
    * @param maxFeatures The number of features to draw.
-   * @return Sampled features dataset, a two dimensional array.
+   * @return Tuple (sampledFeaturesDataset, featureIdxArr), where featureIdxArr stores
+   *         the original feature indices selected by the feature sampling.
    */
   private[iforest] def sampleFeatures(
       data: Array[Vector],
-      maxFeatures: Double): Array[Array[Double]] = {
+      maxFeatures: Double): (Array[Array[Double]], Array[Int]) = {
     // get feature size
     val numFeatures = data.head.size
     // calculate the number of sampling features
     val subFeatures: Int =
       if (maxFeatures <= 1) (maxFeatures * numFeatures).toInt
-      else if (maxFeatures > numFeatures) numFeatures
+      else if (maxFeatures > numFeatures) {
+        logger.warn("maxFeatures is larger than the numFeatures, using all features instead")
+        numFeatures
+      }
       else maxFeatures.toInt

-    if (maxFeatures == numFeatures) {
-      data.toArray.map(vector => vector.asInstanceOf[DenseVector].values)
+    if (subFeatures == numFeatures) {
+      (data.toArray.map(vector => vector.asInstanceOf[DenseVector].values), Array.range(0, numFeatures))
     } else {
       // feature index for sampling features
       val featureIdx = rng.shuffle(0 to numFeatures - 1).take(subFeatures)
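The `subFeatures` rule above can be read as: a fraction of `numFeatures` when `maxFeatures <= 1`, capped at `numFeatures` when it is too large (with the new warning), otherwise an absolute count. A small restatement for illustration (warning branch elided):

```scala
def subFeatureCount(maxFeatures: Double, numFeatures: Int): Int =
  if (maxFeatures <= 1) (maxFeatures * numFeatures).toInt  // fraction of the feature count
  else if (maxFeatures > numFeatures) numFeatures          // fall back to all features
  else maxFeatures.toInt                                   // absolute count

// subFeatureCount(0.5, 10) == 5   subFeatureCount(4, 3) == 3   subFeatureCount(2, 3) == 2
```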
@@ -558,9 +562,8 @@ class IForest (
         featureIdx.zipWithIndex.foreach(elem => sampledValues(elem._2) = vector(elem._1))
         sampledFeatures += sampledValues
       })
-      sampledFeatures.result()
+      (sampledFeatures.result(), featureIdx.toArray)
     }
   }

   /**
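A standalone sketch of what the sampling branch above returns, using plain arrays in place of ml `Vector`s (hypothetical helper, not the library API): the tuple pairs the feature-reduced rows with the original column indices they were taken from.

```scala
import scala.util.Random

def sampleFeaturesSketch(
    data: Array[Array[Double]],
    subFeatures: Int,
    rng: Random): (Array[Array[Double]], Array[Int]) = {
  val numFeatures = data.head.length
  val featureIdx = rng.shuffle((0 until numFeatures).toList).take(subFeatures)
  val sampled = data.map(row => featureIdx.map(i => row(i)).toArray)
  (sampled, featureIdx.toArray)
}

// e.g. for 4 rows of 3 features and subFeatures = 2, the result is
// 4 rows of length 2 plus the 2 original column indices that were kept.
```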
@@ -572,12 +575,14 @@ class IForest (
    * @param maxDepth height limit during building a tree
    * @param constantFeatures an array stores constant features indices, constant features
    *                         will not be drawn
+   * @param featureIdxArr an array storing the mapping from sampled feature indices to the original feature indices
    * @return tree's root node
    */
   private[iforest] def iTree(data: Array[Array[Double]],
       currentPathLength: Int,
       maxDepth: Int,
-      constantFeatures: Array[Int]): IFNode = {
+      constantFeatures: Array[Int],
+      featureIdxArr: Array[Int]): IFNode = {
     var constantFeatureIndex = constantFeatures.last
     // the condition of leaf node
@@ -620,9 +625,9 @@ class IForest (
       val rightData = data.filter(point => point(attrIndex) >= attrValue)
       // recursively build a tree
       new IFInternalNode(
-        iTree(leftData, currentPathLength + 1, maxDepth, constantFeatures.clone()),
-        iTree(rightData, currentPathLength + 1, maxDepth, constantFeatures.clone()),
-        attrIndex, attrValue)
+        iTree(leftData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr),
+        iTree(rightData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr),
+        featureIdxArr(attrIndex), attrValue)
     }
   }
 }
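Why the last line of the hunk matters: `attrIndex` addresses a column of the per-tree sampled matrix, while prediction walks full-width vectors, so the node must record `featureIdxArr(attrIndex)`. A simplified, self-contained illustration (not the library's `IFNode` classes, and ignoring the usual leaf-size adjustment):

```scala
sealed trait Node
case class Leaf(size: Int) extends Node
case class Internal(left: Node, right: Node, featureIdx: Int, splitValue: Double) extends Node

// featureIdx must be an index into the *original* feature vector for this traversal
// to read the same column that was used when the split was chosen during training.
def pathLength(node: Node, point: Array[Double], depth: Int = 0): Int = node match {
  case Leaf(_) => depth
  case Internal(left, right, featureIdx, splitValue) =>
    if (point(featureIdx) < splitValue) pathLength(left, point, depth + 1)
    else pathLength(right, point, depth + 1)
}
```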
@@ -7,6 +7,8 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import scala.util.Random

 class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -84,14 +86,16 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
   test("sample features") {
     val data = IForestSuite.generateIVectorArray(4, 3)
-    val iforest = new IForest()
-    val sampleResult = iforest.sampleFeatures(data, 4)
+    val iforest = new IForest().setSeed(123456L)
+    val (sampleResult, featureIdxArr) = iforest.sampleFeatures(data, 4)
     assert(sampleResult.length === 4 && sampleResult(0).length === 3 &&
       sampleResult(1).length === 3 && sampleResult(2).length === 3)
+    assert(featureIdxArr.length === 3 && featureIdxArr(0) === 0 && featureIdxArr(1) === 1 && featureIdxArr(2) === 2)

-    val sampleResult2 = iforest.sampleFeatures(data, 2)
+    val (sampleResult2, featureIdxArr2) = iforest.sampleFeatures(data, 2)
     assert(sampleResult2.length === 4 && sampleResult2(0).length === 2 &&
       sampleResult2(1).length === 2 && sampleResult2(2).length === 2)
+    assert(featureIdxArr2.length === 2)
   }

   test("fit, transform and summary") {
@@ -103,9 +107,10 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
       .setPredictionCol(predictionColName)
       .setAnomalyScoreCol(anomalyScoreName)
       .setContamination(0.2)
+      .setMaxFeatures(0.5)
       .setSeed(123L)
     val model = iforest.fit(dataset)
-    assert(model.trees.length === 10)
+    assert(model.trees.length === 10)
     val summary = model.summary
     val anomalies = summary.anomalies.collect
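The `.setMaxFeatures(0.5)` added above makes this end-to-end test exercise feature sub-sampling, and therefore the new index mapping. A hedged usage sketch, using only setters that appear in this diff and assuming `dataset` holds a vector column of features:

```scala
import org.apache.spark.sql.{DataFrame, Dataset}

// Illustrative usage only; parameter values follow the test above.
def fitAndScore(dataset: Dataset[_]): DataFrame = {
  val iforest = new IForest()
    .setMaxFeatures(0.5)      // each tree draws 50% of the feature columns
    .setContamination(0.2)
    .setSeed(123L)
  val model = iforest.fit(dataset)
  model.transform(dataset)    // adds the configured prediction / anomaly-score columns
}
```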
@@ -248,9 +253,10 @@ object IForestSuite {
     spark.createDataFrame(sc.parallelize(data))
   }

-  def generateIVectorArray(row: Int, col: Int): Array[Vector] = {
+  def generateIVectorArray(row: Int, col: Int, seed: Long = 100L): Array[Vector] = {
+    val rand = new Random(seed)
     Array.tabulate(row) { i =>
-      Vectors.dense(Array.fill(col)(i.toDouble))
+      Vectors.dense(Array.fill(col)(i.toDouble * rand.nextInt(10)))
     }
   }