Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Maria Karanasou
spark-iforest
Commits
a9e9148b
Commit
a9e9148b
authored
Jul 02, 2019
by
titicaca
Browse files
Fix for model saving and loading
parent
bde17cdc
Changes
4
Hide whitespace changes
Inline
Side-by-side
README.md
View file @
a9e9148b
...
...
@@ -92,6 +92,15 @@ val pipeline = new Pipeline().setStages(Array(indexer, assembler, iForest))
val
model
=
pipeline
.
fit
(
dataset
)
val
predictions
=
model
.
transform
(
dataset
)
// Save pipeline model
model
.
write
.
overwrite
().
save
(
"/tmp/iforest.model"
)
// Load pipeline model
val
loadedPipelineModel
=
PipelineModel
.
load
(
"/tmp/iforest.model"
)
// Get loaded iforest model
val
loadedIforestModel
=
loadedPipelineModel
.
stages
(
2
).
asInstanceOf
[
IForestModel
]
println
(
s
"The loaded iforest model has no summary: model.hasSummary = ${loadedIforestModel.hasSummary}"
)
val
binaryMetrics
=
new
BinaryClassificationMetrics
(
predictions
.
select
(
"prediction"
,
"label"
).
rdd
.
map
{
case
Row
(
label
:
Double
,
ground
:
Double
)
=>
(
label
,
ground
)
...
...
@@ -105,6 +114,10 @@ println(s"The model's auc: ${binaryMetrics.areaUnderROC()}")
*Python API*
```
python
from
pyspark.ml.linalg
import
Vectors
import
tempfile
spark
=
SparkSession
\
.
builder
.
master
(
"local[*]"
)
\
.
appName
(
"IForestExample"
)
\
...
...
src/main/scala/org/apache/spark/examples/ml/IForestExample.scala
View file @
a9e9148b
package
org.apache.spark.examples.ml
import
org.apache.spark.ml.Pipeline
import
org.apache.spark.ml.
{
Pipeline
,
PipelineModel
}
import
org.apache.spark.ml.feature.
{
StringIndexer
,
VectorAssembler
}
import
org.apache.spark.ml.iforest.IForest
import
org.apache.spark.ml.iforest.
{
IForest
,
IForestModel
}
import
org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import
org.apache.spark.sql.
{
Row
,
SparkSession
}
...
...
@@ -48,6 +48,15 @@ object IForestExample {
val
model
=
pipeline
.
fit
(
dataset
)
val
predictions
=
model
.
transform
(
dataset
)
// Save pipeline model
model
.
write
.
overwrite
().
save
(
"/tmp/iforest.model"
)
// Load pipeline model
val
loadedPipelineModel
=
PipelineModel
.
load
(
"/tmp/iforest.model"
)
// Get loaded iforest model
val
loadedIforestModel
=
loadedPipelineModel
.
stages
(
2
).
asInstanceOf
[
IForestModel
]
println
(
s
"The loaded iforest model has no summary: model.hasSummary = ${loadedIforestModel.hasSummary}"
)
val
binaryMetrics
=
new
BinaryClassificationMetrics
(
predictions
.
select
(
"prediction"
,
"label"
).
rdd
.
map
{
case
Row
(
label
:
Double
,
ground
:
Double
)
=>
(
label
,
ground
)
...
...
src/main/scala/org/apache/spark/ml/iforest/IForest.scala
View file @
a9e9148b
...
...
@@ -752,8 +752,8 @@ trait IForestParams extends Params {
* Relative Error for Approximate Quantile (0 <= value <= 1), default is 0.
* @group param
*/
final
val
approxQuantileRelativeError
:
Param
[
Double
]
=
new
Param
[
Double
]
(
parent
=
this
,
name
=
"approxQuantileRelativeError"
,
doc
=
"relative error for approximate quantile"
)
final
val
approxQuantileRelativeError
:
Double
Param
=
new
Double
Param
(
parent
=
this
,
name
=
"approxQuantileRelativeError"
,
doc
=
"relative error for approximate quantile"
)
/** @group setParam */
setDefault
(
approxQuantileRelativeError
,
value
=
0d
)
...
...
src/test/scala/org/apache/spark/ml/iforest/IForestSuite.scala
View file @
a9e9148b
...
...
@@ -159,45 +159,45 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
assert
(
model1
.
summary
.
anomalyScoreCol
===
model2
.
summary
.
anomalyScoreCol
)
}
//
Uncomment the codes below to run the
test for model read/write
//
test("read/write") {
//
def checkTreeNodes(node: IFNode, node2: IFNode): Unit = {
//
(node, node2) match {
//
case (node: IFInternalNode, node2: IFInternalNode) =>
//
assert(node.featureValue === node2.featureValue)
//
assert(node.featureIndex === node2.featureIndex)
//
checkTreeNodes(node.leftChild, node2.leftChild)
//
checkTreeNodes(node.rightChild, node2.rightChild)
//
case (node: IFLeafNode, node2: IFLeafNode) =>
//
assert(node.numInstance === node2.numInstance)
//
case _ =>
//
throw new AssertionError("Found mismatched nodes")
//
}
//
}
//
def checkModelData(model: IForestModel, model2: IForestModel): Unit = {
//
val trees = model.trees
//
val trees2 = model2.trees
//
assert(trees.length === trees2.length)
//
try {
//
trees.zip(trees2).foreach { case (node, node2) =>
//
checkTreeNodes(node, node2)
//
}
//
} catch {
//
case ex: Exception => throw new AssertionError(
//
"checkModelData failed since the two trees were not identical.\n"
//
)
//
}
//
}
//
//
val iforest = new IForest()
//
testEstimatorAndModelReadWrite(
//
iforest,
//
dataset,
//
IForestSuite.allParamSettings,
//
IForestSuite.allParamSettings,
//
checkModelData
//
)
//
}
// test for model read/write
test
(
"read/write"
)
{
def
checkTreeNodes
(
node
:
IFNode
,
node2
:
IFNode
)
:
Unit
=
{
(
node
,
node2
)
match
{
case
(
node
:
IFInternalNode
,
node2
:
IFInternalNode
)
=>
assert
(
node
.
featureValue
===
node2
.
featureValue
)
assert
(
node
.
featureIndex
===
node2
.
featureIndex
)
checkTreeNodes
(
node
.
leftChild
,
node2
.
leftChild
)
checkTreeNodes
(
node
.
rightChild
,
node2
.
rightChild
)
case
(
node
:
IFLeafNode
,
node2
:
IFLeafNode
)
=>
assert
(
node
.
numInstance
===
node2
.
numInstance
)
case
_
=>
throw
new
AssertionError
(
"Found mismatched nodes"
)
}
}
def
checkModelData
(
model
:
IForestModel
,
model2
:
IForestModel
)
:
Unit
=
{
val
trees
=
model
.
trees
val
trees2
=
model2
.
trees
assert
(
trees
.
length
===
trees2
.
length
)
try
{
trees
.
zip
(
trees2
).
foreach
{
case
(
node
,
node2
)
=>
checkTreeNodes
(
node
,
node2
)
}
}
catch
{
case
ex
:
Exception
=>
throw
new
AssertionError
(
"checkModelData failed since the two trees were not identical.\n"
)
}
}
val
iforest
=
new
IForest
()
testEstimatorAndModelReadWrite
(
iforest
,
dataset
,
IForestSuite
.
allParamSettings
,
IForestSuite
.
allParamSettings
,
checkModelData
)
}
test
(
"boundary case"
)
{
intercept
[
IllegalArgumentException
]
{
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment