Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Maria Karanasou
spark-iforest
Commits
81ca5085
Commit
81ca5085
authored
Dec 27, 2018
by
titicaca
Browse files
upgrade to spark v2.4.0
parent
4848fe69
Changes
3
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
81ca5085
...
...
@@ -6,7 +6,7 @@
<groupId>
org.apache.spark.ml
</groupId>
<artifactId>
spark-iforest
</artifactId>
<version>
2.
3
.0
</version>
<version>
2.
4
.0
</version>
<properties>
<project.reporting.outputEncoding>
UTF-8
</project.reporting.outputEncoding>
...
...
@@ -18,7 +18,7 @@
<log4j.version>
1.2.17
</log4j.version>
<skipTests>
false
</skipTests>
<maven.version>
3.3.9
</maven.version>
<spark.version>
2.
3
.0
</spark.version>
<spark.version>
2.
4
.0
</spark.version>
</properties>
<dependencies>
...
...
src/main/scala/org/apache/spark/ml/iforest/IForest.scala
View file @
81ca5085
...
...
@@ -11,6 +11,7 @@ import org.apache.spark.ml.{Estimator, Model}
import
org.apache.spark.ml.linalg._
import
org.apache.spark.ml.param._
import
org.apache.spark.ml.util._
import
org.apache.spark.ml.util.Instrumentation.instrumented
import
org.apache.spark.rdd.RDD
import
org.apache.spark.sql.
{
DataFrame
,
Dataset
,
Row
,
SparkSession
}
import
org.apache.spark.sql.functions._
...
...
@@ -287,7 +288,7 @@ object IForestModel extends MLReadable[IForestModel] {
val
metadata
=
DefaultParamsReader
.
loadMetadata
(
path
,
sc
,
className
)
val
trees
=
loadTreeNodes
(
path
,
sparkSession
)
val
model
=
new
IForestModel
(
metadata
.
uid
,
trees
)
DefaultParamsReader
.
getAndSetParams
(
model
,
metadata
)
metadata
.
getAndSetParams
(
model
)
model
}
}
...
...
@@ -471,15 +472,15 @@ class IForest (
* use VectorAssembler to generate a feature column.
* @return trained iforest model with an array of each tree's root node.
*/
override
def
fit
(
dataset
:
Dataset
[
_
])
:
IForestModel
=
{
override
def
fit
(
dataset
:
Dataset
[
_
])
:
IForestModel
=
instrumented
{
instr
=>
transformSchema
(
dataset
.
schema
,
logging
=
true
)
val
rddPerTree
=
splitData
(
dataset
)
// add Instrumentation instance
val
instr
=
Instrumentation
.
create
(
this
,
rddPerTree
)
instr
.
logParams
(
numTrees
,
maxSamples
,
maxFeatures
,
maxDepth
,
contamination
,
bootstrap
,
seed
,
featuresCol
,
predictionCol
,
labelCol
)
instr
.
logPipelineStage
(
this
)
instr
.
logDataset
(
dataset
)
instr
.
logParams
(
this
,
numTrees
,
maxSamples
,
maxFeatures
,
maxDepth
,
contamination
,
bootstrap
,
seed
,
featuresCol
,
predictionCol
,
labelCol
)
// Each iTree of the iForest will be built in parallel and collected in the driver.
// The approximate memory usage of the iForest model is calculated; a warning will be raised if the iForest is too large.
...
...
@@ -520,7 +521,6 @@ class IForest (
predictions
,
$
(
featuresCol
),
$
(
predictionCol
),
$
(
anomalyScoreCol
)
)
model
.
setSummary
(
Some
(
summary
))
instr
.
logSuccess
(
model
)
model
}
...
...
src/test/scala/org/apache/spark/ml/iforest/IForestSuite.scala
View file @
81ca5085
...
...
@@ -36,7 +36,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
.
setNumTrees
(
10
)
.
setMaxSamples
(
10
)
.
setMaxFeatures
(
10
)
.
setMaxDepth
(
10
)
.
setMaxDepth
(
4
)
.
setContamination
(
0.01
)
.
setBootstrap
(
true
)
.
setSeed
(
123L
)
...
...
@@ -48,7 +48,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
assert
(
iforest
.
getNumTrees
===
10
)
assert
(
iforest
.
getMaxSamples
===
10
)
assert
(
iforest
.
getMaxFeatures
===
10
)
assert
(
iforest
.
getMaxDepth
===
10
)
assert
(
iforest
.
getMaxDepth
===
4
)
assert
(
iforest
.
getContamination
===
0.01
)
assert
(
iforest
.
getBootstrap
)
assert
(
iforest
.
getSeed
===
123L
)
...
...
@@ -62,6 +62,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
// test without bootstrap
val
iforest1
=
new
IForest
()
.
setNumTrees
(
2
)
.
setMaxDepth
(
4
)
.
setMaxSamples
(
1.0
)
.
setMaxFeatures
(
1.0
)
.
setBootstrap
(
false
)
...
...
@@ -72,6 +73,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
// test with bootstrap
val
iforest2
=
new
IForest
()
.
setNumTrees
(
2
)
.
setMaxDepth
(
4
)
.
setMaxSamples
(
1.0
)
.
setMaxFeatures
(
1.0
)
.
setBootstrap
(
true
)
...
...
@@ -97,6 +99,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
val
anomalyScoreName
=
"test_anomalyScore"
val
iforest
=
new
IForest
()
.
setNumTrees
(
10
)
.
setMaxDepth
(
4
)
.
setPredictionCol
(
predictionColName
)
.
setAnomalyScoreCol
(
anomalyScoreName
)
.
setContamination
(
0.2
)
...
...
@@ -117,7 +120,7 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
}
test
(
"copy estimator and model"
)
{
val
iforest1
=
new
IForest
()
val
iforest1
=
new
IForest
()
.
setMaxDepth
(
4
)
val
iforest2
=
iforest1
.
copy
(
ParamMap
.
empty
)
iforest1
.
params
.
foreach
{
p
=>
if
(
iforest1
.
isDefined
(
p
))
{
...
...
@@ -255,7 +258,7 @@ object IForestSuite {
"numTrees"
->
1
,
"maxSamples"
->
1.0
,
"maxFeatures"
->
1.0
,
"maxDepth"
->
5
,
"maxDepth"
->
3
,
"contamination"
->
0.2
,
"bootstrap"
->
false
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment