Commit 94f8218c authored by Maria Karanasou's avatar Maria Karanasou
Browse files

minor performance updates + dataset persisting / unpersisting

parent fc0d428c
Pipeline #39 failed with stages
......@@ -85,14 +85,15 @@ class IForestModel (
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val numSamples = dataset.count()
val possibleMaxSamples =
if ($(maxSamples) > 1.0) $(maxSamples) else ($(maxSamples) * numSamples)
val possibleMaxSamples = if ($(maxSamples) > 1.0) $(maxSamples) else ($(maxSamples) * numSamples)
val normFactor = avgLength(possibleMaxSamples)
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val bcastNormFactor = dataset.sparkSession.sparkContext.broadcast(normFactor)
// calculate anomaly score
val scoreUDF = udf { (features: Vector) => {
val normFactor = avgLength(possibleMaxSamples)
val avgPathLength = bcastModel.value.calAvgPathLength(features)
Math.pow(2, -avgPathLength / normFactor)
Math.pow(2, -avgPathLength / bcastNormFactor.value)
}
}
// append a score column
......@@ -100,15 +101,23 @@ class IForestModel (
if (threshold < 0) {
logger.info("threshold is not set, calculating the anomaly threshold according to param contamination..")
threshold = scoreDataset.stat.approxQuantile($(anomalyScoreCol),
Array(1 - $(contamination)), $(approxQuantileRelativeError))(0)
scoreDataset.persist()
logger.debug("Persisted scoreDataset..")
threshold = scoreDataset.stat.approxQuantile(
$(anomalyScoreCol),
Array(1 - $(contamination)),
$(approxQuantileRelativeError)
)(0)
}
// set anomaly instance label 1
val predictUDF = udf { (anomalyScore: Double) =>
if (anomalyScore > threshold) 1.0 else 0.0
}
scoreDataset.withColumn($(predictionCol), predictUDF(col($(anomalyScoreCol))))
dataset.unpersist()
scoreDataset.withColumn(
$(predictionCol),
when(
col($(anomalyScoreCol)) > lit(threshold), 1.0
).otherwise(0.0)
)
}
/**
......@@ -129,7 +138,7 @@ class IForestModel (
* @param ifNode Tree's root node.
* @param currentPathLength Current path length.
* @return Path length in this tree.
*/
*/
private def calPathLength(features: Vector,
ifNode: IFNode,
currentPathLength: Int): Double = ifNode match {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment