Commit 66b62427 authored by fzyang


Fix the random seed problem when generating iTrees. Add a threshold param to IForestModel, so that a fitted model can predict on a new dataset.
parent 588f9ab7
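
The two changes work together: the seeded per-tree generator makes training reproducible, and the stored threshold lets the training-time cut-off be reused on unseen data. A minimal usage sketch (not part of this commit; it assumes IForest follows Spark ML's Estimator/Transformer pattern, and trainDF/newDF are placeholder DataFrames with a features column):

    // Fit on training data; the first transform() derives the threshold
    // from the `contamination` param and stores it on the model.
    val model = new IForest().fit(trainDF)
    model.transform(trainDF)

    // Reuse the training-time cut-off on unseen data instead of
    // re-estimating a quantile there.
    val t = model.getThreshold()
    model.setThreshold(t)
    val scoredNew = model.transform(newDF)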
@@ -46,6 +46,9 @@ class IForestModel (
   private var trainingSummary: Option[IForestSummary] = None

+  // Threshold for anomaly score. Default is -1.
+  private var threshold: Double = -1d
+
   private[iforest] def setSummary(summary: Option[IForestSummary]): this.type = {
     this.trainingSummary = summary
     this
@@ -62,6 +65,15 @@ class IForestModel (
     )
   }

+  def getThreshold(): Double = {
+    this.threshold
+  }
+
+  def setThreshold(value: Double): this.type = {
+    this.threshold = value
+    this
+  }
+
   /**
    * Predict if a particular sample is an outlier or not.
    * @param dataset Input data which is a dataset with n_samples rows. This dataset must have a
@@ -85,12 +97,16 @@ class IForestModel (
     }

     // append a score column
     val scoreDataset = dataset.withColumn($(anomalyScoreCol), scoreUDF(col($(featuresCol))))
-    val threshold = scoreDataset.stat.approxQuantile($(anomalyScoreCol),
-      Array(1 - $(contamination)), $(approxQuantileRelativeError))
+    // get threshold value
+    if (threshold < 0) {
+      logger.info("threshold is not set, calculating the anomaly threshold according to param contamination..")
+      threshold = scoreDataset.stat.approxQuantile($(anomalyScoreCol),
+        Array(1 - $(contamination)), $(approxQuantileRelativeError))(0)
+    }
     // set anomaly instance label 1
     val predictUDF = udf { (anomalyScore: Double) =>
-      if (anomalyScore > threshold(0)) 1.0 else 0.0
+      if (anomalyScore > threshold) 1.0 else 0.0
     }
     scoreDataset.withColumn($(predictionCol), predictUDF(col($(anomalyScoreCol))))
   }
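
The threshold is the (1 - contamination) quantile of the anomaly-score column, computed lazily only when no threshold has been set yet. A standalone sketch of that cut (approxQuantile is the real DataFrameStatFunctions API; the column name and values here are made up):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder.master("local[*]").getOrCreate()
    import spark.implicits._

    val scores = Seq(0.41, 0.43, 0.45, 0.47, 0.49,
                     0.51, 0.53, 0.55, 0.70, 0.80).toDF("anomalyScore")

    // contamination = 0.2 => cut at the 0.8 quantile; relativeError 0.0
    // makes the quantile exact (at the cost of a full pass over the data)
    val Array(threshold) = scores.stat.approxQuantile("anomalyScore", Array(0.8), 0.0)
    // threshold lands near 0.55, so only the rows scoring 0.70 and 0.80
    // exceed it and would be labeled 1.0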
@@ -497,8 +513,11 @@ class IForest (
     // build each tree and construct a forest
     val _trees = rddPerTree.map {
       case (treeId: Int, points: Array[Vector]) =>
+        // create a random generator for iTree generation
+        val random = new Random(rng.nextInt + treeId)
         // sample features
-        val (trainData, featureIdxArr) = sampleFeatures(points, $(maxFeatures))
+        val (trainData, featureIdxArr) = sampleFeatures(points, $(maxFeatures), random)
         // calculate actual maxDepth to limit tree height
         val longestPath = Math.ceil(Math.log(Math.max(2, points.length)) / Math.log(2)).toInt
@@ -515,7 +534,7 @@ class IForest (
         // last position's value indicates constant feature offset index
         constantFeatures(numFeatures) = 0
         // build a tree
-        iTree(trainData, 0, possibleMaxDepth, constantFeatures, featureIdxArr)
+        iTree(trainData, 0, possibleMaxDepth, constantFeatures, featureIdxArr, random)
     }.collect()
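
The per-tree generator is the heart of the seed fix: each tree draws from its own Random, derived from the driver-side rng plus the tree id, so trees differ from one another while a fixed master seed still reproduces the whole forest. A self-contained sketch of the scheme (tree count and seeds are made up):

    import scala.util.Random

    val rng = new Random(42L)                  // master generator (fixed seed)
    val perTreeRandoms = (0 until 10).map { treeId =>
      new Random(rng.nextInt + treeId)         // same derivation as above
    }

    // distinct streams per tree, but rerunning with seed 42L yields the
    // exact same ten sequences again
    val firstDraws = perTreeRandoms.map(_.nextDouble())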
@@ -524,6 +543,7 @@ class IForest (
     val summary = new IForestSummary(
       predictions, $(featuresCol), $(predictionCol), $(anomalyScoreCol)
     )
+
     model.setSummary(Some(summary))
     model
   }
@@ -537,7 +557,8 @@ class IForest (
    */
   private[iforest] def sampleFeatures(
       data: Array[Vector],
-      maxFeatures: Double): (Array[Array[Double]], Array[Int]) = {
+      maxFeatures: Double,
+      random: Random = new Random()): (Array[Array[Double]], Array[Int]) = {

     // get feature size
     val numFeatures = data.head.size
@@ -554,7 +575,7 @@ class IForest (
       (data.toArray.map(vector => vector.asInstanceOf[DenseVector].values), Array.range(0, numFeatures))
     } else {
       // feature index for sampling features
-      val featureIdx = rng.shuffle(0 to numFeatures - 1).take(subFeatures)
+      val featureIdx = random.shuffle(0 to numFeatures - 1).take(subFeatures)

       val sampledFeatures = mutable.ArrayBuilder.make[Array[Double]]
       data.foreach(vector => {
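
sampleFeatures now draws its column subset from the passed-in generator instead of the shared rng. The sampling itself is a shuffle-and-take over the index range; a small sketch with made-up sizes:

    import scala.util.Random

    val numFeatures = 10
    val subFeatures = 4
    val random = new Random(7L)

    // shuffle the full index range, then keep the first subFeatures entries:
    // a uniform sample of distinct column indices
    val featureIdx = random.shuffle(0 to numFeatures - 1).take(subFeatures)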
@@ -576,13 +597,15 @@ class IForest (
    * @param constantFeatures an array stores constant features indices, constant features
    *                         will not be drawn
    * @param featureIdxArr an array stores the mapping from the sampled feature idx to the origin feature idx
+   * @param random the random generator used while building the iTree
    * @return tree's root node
    */
   private[iforest] def iTree(data: Array[Array[Double]],
       currentPathLength: Int,
       maxDepth: Int,
       constantFeatures: Array[Int],
-      featureIdxArr: Array[Int]): IFNode = {
+      featureIdxArr: Array[Int],
+      random: Random): IFNode = {

     var constantFeatureIndex = constantFeatures.last
     // the condition of leaf node
@@ -600,7 +623,7 @@ class IForest (
       var findConstant = true
       while (findConstant && numFeatures != constantFeatureIndex) {
         // select randomly a feature index
-        val idx = rng.nextInt(numFeatures - constantFeatureIndex) + constantFeatureIndex
+        val idx = random.nextInt(numFeatures - constantFeatureIndex) + constantFeatureIndex
         attrIndex = constantFeatures(idx)
         val features = Array.tabulate(data.length)( i => data(i)(attrIndex))
         attrMin = features.min
@@ -619,14 +642,14 @@ class IForest (
       if (numFeatures == constantFeatureIndex) new IFLeafNode(data.length)
       else {
         // select randomly a feature value between (attrMin, attrMax)
-        val attrValue = rng.nextDouble() * (attrMax - attrMin) + attrMin
+        val attrValue = random.nextDouble() * (attrMax - attrMin) + attrMin
         // split data according to the attrValue
         val leftData = data.filter(point => point(attrIndex) < attrValue)
         val rightData = data.filter(point => point(attrIndex) >= attrValue)
         // recursively build a tree
         new IFInternalNode(
-          iTree(leftData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr),
-          iTree(rightData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr),
+          iTree(leftData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr, random),
+          iTree(rightData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr, random),
           featureIdxArr(attrIndex), attrValue)
       }
     }
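
Threading the same Random through both recursive calls keeps each tree fully deterministic for its seed. The split itself is a uniform draw between the selected column's min and max; a sketch with made-up inputs:

    import scala.util.Random

    val random = new Random(99L)
    val (attrMin, attrMax) = (1.5, 8.0)
    // uniform split point in [attrMin, attrMax), as in iTree above
    val attrValue = random.nextDouble() * (attrMax - attrMin) + attrMin

    val column = Array(2.0, 3.5, 7.2, 1.8, 6.6)
    // same partition as the two filter calls
    val (left, right) = column.partition(_ < attrValue)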
...