Commit 81ca5085 authored by titicaca's avatar titicaca
Browse files

upgrade to spark v2.4.0

parent 4848fe69
......@@ -6,7 +6,7 @@
<groupId>org.apache.spark.ml</groupId>
<artifactId>spark-iforest</artifactId>
<version>2.3.0</version>
<version>2.4.0</version>
<properties>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
......@@ -18,7 +18,7 @@
<log4j.version>1.2.17</log4j.version>
<skipTests>false</skipTests>
<maven.version>3.3.9</maven.version>
<spark.version>2.3.0</spark.version>
<spark.version>2.4.0</spark.version>
</properties>
<dependencies>
......
......@@ -11,6 +11,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
......@@ -287,7 +288,7 @@ object IForestModel extends MLReadable[IForestModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val trees = loadTreeNodes(path, sparkSession)
val model = new IForestModel(metadata.uid, trees)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
......@@ -471,15 +472,15 @@ class IForest (
* use VectorAssembler to generate a feature column.
* @return trained iforest model with an array of each tree's root node.
*/
override def fit(dataset: Dataset[_]): IForestModel = {
override def fit(dataset: Dataset[_]): IForestModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)
val rddPerTree = splitData(dataset)
// add Instrumentation instance
val instr = Instrumentation.create(this, rddPerTree)
instr.logParams(numTrees, maxSamples, maxFeatures, maxDepth, contamination,
bootstrap, seed, featuresCol, predictionCol, labelCol)
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, numTrees, maxSamples, maxFeatures, maxDepth, contamination,
bootstrap, seed, featuresCol, predictionCol, labelCol)
// Each iTree of the iForest will be built on parallel and collected in the driver.
// Approximate memory usage for iForest model is calculated, a warning will be raised if iForest is too large.
......@@ -520,7 +521,6 @@ class IForest (
predictions, $(featuresCol), $(predictionCol), $(anomalyScoreCol)
)
model.setSummary(Some(summary))
instr.logSuccess(model)
model
}
......
......@@ -36,7 +36,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
.setNumTrees(10)
.setMaxSamples(10)
.setMaxFeatures(10)
.setMaxDepth(10)
.setMaxDepth(4)
.setContamination(0.01)
.setBootstrap(true)
.setSeed(123L)
......@@ -48,7 +48,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
assert(iforest.getNumTrees === 10)
assert(iforest.getMaxSamples === 10)
assert(iforest.getMaxFeatures === 10)
assert(iforest.getMaxDepth === 10)
assert(iforest.getMaxDepth === 4)
assert(iforest.getContamination === 0.01)
assert(iforest.getBootstrap)
assert(iforest.getSeed === 123L)
......@@ -62,6 +62,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
// test with bootsrap
val iforest1 = new IForest()
.setNumTrees(2)
.setMaxDepth(4)
.setMaxSamples(1.0)
.setMaxFeatures(1.0)
.setBootstrap(false)
......@@ -72,6 +73,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
// test without bootstrap
val iforest2 = new IForest()
.setNumTrees(2)
.setMaxDepth(4)
.setMaxSamples(1.0)
.setMaxFeatures(1.0)
.setBootstrap(true)
......@@ -97,6 +99,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
val anomalyScoreName = "test_anomalyScore"
val iforest = new IForest()
.setNumTrees(10)
.setMaxDepth(4)
.setPredictionCol(predictionColName)
.setAnomalyScoreCol(anomalyScoreName)
.setContamination(0.2)
......@@ -117,7 +120,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
}
test("copy estimator and model") {
val iforest1 = new IForest()
val iforest1 = new IForest().setMaxDepth(4)
val iforest2 = iforest1.copy(ParamMap.empty)
iforest1.params.foreach { p =>
if (iforest1.isDefined(p)) {
......@@ -255,7 +258,7 @@ object IForestSuite {
"numTrees" -> 1,
"maxSamples" -> 1.0,
"maxFeatures" -> 1.0,
"maxDepth" -> 5,
"maxDepth" -> 3,
"contamination" -> 0.2,
"bootstrap" -> false
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment