Commit 66b62427 authored by fzyang's avatar fzyang
Browse files

fix for random seed problem during generaring itree. Add param threshold in...

fix for random seed problem during generaring itree. Add param threshold in iforestModel, in order to predict in a new dataset
parent 588f9ab7
......@@ -46,6 +46,9 @@ class IForestModel (
private var trainingSummary: Option[IForestSummary] = None
// Threshold for anomaly score. Default is -1.
private var threshold: Double = -1d
private[iforest] def setSummary(summary: Option[IForestSummary]): this.type = {
this.trainingSummary = summary
this
......@@ -62,6 +65,15 @@ class IForestModel (
)
}
def getThreshold(): Double = {
this.threshold
}
def setThreshold(value: Double): this.type = {
this.threshold = value
this
}
/**
* Predict if a particular sample is an outlier or not.
* @param dataset Input data which is a dataset with n_samples rows. This dataset must have a
......@@ -85,12 +97,16 @@ class IForestModel (
}
// append a score column
val scoreDataset = dataset.withColumn($(anomalyScoreCol), scoreUDF(col($(featuresCol))))
// get threshold value
val threshold = scoreDataset.stat.approxQuantile($(anomalyScoreCol),
Array(1 - $(contamination)), $(approxQuantileRelativeError))
if (threshold < 0) {
logger.info("threshold is not set, calculating the anomaly threshold according to param contamination..")
threshold = scoreDataset.stat.approxQuantile($(anomalyScoreCol),
Array(1 - $(contamination)), $(approxQuantileRelativeError))(0)
}
// set anomaly instance label 1
val predictUDF = udf { (anomalyScore: Double) =>
if (anomalyScore > threshold(0)) 1.0 else 0.0
if (anomalyScore > threshold) 1.0 else 0.0
}
scoreDataset.withColumn($(predictionCol), predictUDF(col($(anomalyScoreCol))))
}
......@@ -497,8 +513,11 @@ class IForest (
// build each tree and construct a forest
val _trees = rddPerTree.map {
case (treeId: Int, points: Array[Vector]) =>
// Create a random for iTree generation
val random = new Random(rng.nextInt + treeId)
// sample features
val (trainData, featureIdxArr) = sampleFeatures(points, $(maxFeatures))
val (trainData, featureIdxArr) = sampleFeatures(points, $(maxFeatures), random)
// calculate actual maxDepth to limit tree height
val longestPath = Math.ceil(Math.log(Math.max(2, points.length)) / Math.log(2)).toInt
......@@ -515,7 +534,7 @@ class IForest (
// last position's value indicates constant feature offset index
constantFeatures(numFeatures) = 0
// build a tree
iTree(trainData, 0, possibleMaxDepth, constantFeatures, featureIdxArr)
iTree(trainData, 0, possibleMaxDepth, constantFeatures, featureIdxArr, random)
}.collect()
......@@ -524,6 +543,7 @@ class IForest (
val summary = new IForestSummary(
predictions, $(featuresCol), $(predictionCol), $(anomalyScoreCol)
)
model.setSummary(Some(summary))
model
}
......@@ -537,7 +557,8 @@ class IForest (
*/
private[iforest] def sampleFeatures(
data: Array[Vector],
maxFeatures: Double): (Array[Array[Double]], Array[Int]) = {
maxFeatures: Double,
random: Random = new Random()): (Array[Array[Double]], Array[Int]) = {
// get feature size
val numFeatures = data.head.size
......@@ -554,7 +575,7 @@ class IForest (
(data.toArray.map(vector => vector.asInstanceOf[DenseVector].values), Array.range(0, numFeatures))
} else {
// feature index for sampling features
val featureIdx = rng.shuffle(0 to numFeatures - 1).take(subFeatures)
val featureIdx = random.shuffle(0 to numFeatures - 1).take(subFeatures)
val sampledFeatures = mutable.ArrayBuilder.make[Array[Double]]
data.foreach(vector => {
......@@ -576,13 +597,15 @@ class IForest (
* @param constantFeatures an array stores constant features indices, constant features
* will not be drawn
* @param featureIdxArr an array stores the mapping from the sampled feature idx to the origin feature idx
* @param randomSeed random for generating iTree
* @return tree's root node
*/
private[iforest] def iTree(data: Array[Array[Double]],
currentPathLength: Int,
maxDepth: Int,
constantFeatures: Array[Int],
featureIdxArr: Array[Int]): IFNode = {
featureIdxArr: Array[Int],
random: Random): IFNode = {
var constantFeatureIndex = constantFeatures.last
// the condition of leaf node
......@@ -600,7 +623,7 @@ class IForest (
var findConstant = true
while (findConstant && numFeatures != constantFeatureIndex) {
// select randomly a feature index
val idx = rng.nextInt(numFeatures - constantFeatureIndex) + constantFeatureIndex
val idx = random.nextInt(numFeatures - constantFeatureIndex) + constantFeatureIndex
attrIndex = constantFeatures(idx)
val features = Array.tabulate(data.length)( i => data(i)(attrIndex))
attrMin = features.min
......@@ -619,14 +642,14 @@ class IForest (
if (numFeatures == constantFeatureIndex) new IFLeafNode(data.length)
else {
// select randomly a feature value between (attrMin, attrMax)
val attrValue = rng.nextDouble() * (attrMax - attrMin) + attrMin
val attrValue = random.nextDouble() * (attrMax - attrMin) + attrMin
// split data according to the attrValue
val leftData = data.filter(point => point(attrIndex) < attrValue)
val rightData = data.filter(point => point(attrIndex) >= attrValue)
// recursively build a tree
new IFInternalNode(
iTree(leftData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr),
iTree(rightData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr),
iTree(leftData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr, random),
iTree(rightData, currentPathLength + 1, maxDepth, constantFeatures.clone(), featureIdxArr, random),
featureIdxArr(attrIndex), attrValue)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment