Commit 960c6a24 authored by titicaca's avatar titicaca
Browse files

Issue #2 fix the bug of constant features when building itree

parent e85acc1e
......@@ -498,6 +498,9 @@ class IForest (
// calculate actual maxDepth to limit tree height
val longestPath = Math.ceil(Math.log(Math.max(2, points.length)) / Math.log(2)).toInt
val possibleMaxDepth = if ($(maxDepth) > longestPath) longestPath else $(maxDepth)
if(possibleMaxDepth != $(maxDepth)) {
logger.warn("building itree using possible max depth " + possibleMaxDepth + ", instead of " + $(maxDepth))
}
val numFeatures = trainData.head.size
// a array stores constant features index
......@@ -559,12 +562,12 @@ class IForest (
/**
* Builds a tree
*
* @param data Input data, a two dimensional array, can be regearded as a table, each row
* @param data Input data, a two dimensional array, can be regarded as a table, each row
* is an instance, each column is a feature value.
* @param currentPathLength current node's path length
* @param maxDepth height limit during building a tree
* @param constantFeatures an array stores constant features indices, feature
* will not to draw when it is contant
* @param constantFeatures an array stores constant features indices, constant features
* will not be drawn
* @return tree's root node
*/
private[iforest] def iTree(data: Array[Array[Double]],
......@@ -588,16 +591,15 @@ class IForest (
var findConstant = true
while (findConstant && numFeatures != constantFeatureIndex) {
// select randomly a feature index
attrIndex = constantFeatures(rng.nextInt(numFeatures - constantFeatureIndex) +
constantFeatureIndex)
val idx = rng.nextInt(numFeatures - constantFeatureIndex) + constantFeatureIndex
attrIndex = constantFeatures(idx)
val features = Array.tabulate(data.length)( i => data(i)(attrIndex))
attrMin = features.min
attrMax = features.max
if (attrMin == attrMax) {
// swap constant feature index with non-constant feature index
val tmp = constantFeatures(attrIndex)
constantFeatures(constantFeatureIndex) = tmp
constantFeatures(attrIndex) = constantFeatures(constantFeatureIndex)
constantFeatures(idx) = constantFeatures(constantFeatureIndex)
constantFeatures(constantFeatureIndex) = attrIndex
// constant feature index add 1, then update
constantFeatureIndex += 1
constantFeatures(constantFeatures.length - 1) = constantFeatureIndex
......
......@@ -100,12 +100,14 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
.setPredictionCol(predictionColName)
.setAnomalyScoreCol(anomalyScoreName)
.setContamination(0.2)
.setSeed(123L)
val model = iforest.fit(dataset)
assert(model.trees.length === 10)
val summary = model.summary
val anomalies = summary.anomalies.collect
assert(anomalies.length === 10)
// TODO In Spark 2.3.x, function approxQuantile seems to be changed, numAnomalies might be not accurate.
assert(summary.numAnomalies === 2)
val transformed = model.transform(dataset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment