Commit a9e9148b authored by titicaca

Fix for model saving and loading

parent bde17cdc
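
As a quick sanity check of what this fix enables, here is a minimal sketch (not part of the commit) that saves the fitted pipeline, loads it back, and confirms the reloaded model behaves like the original. It assumes the `model` and `dataset` built in the README example below; variable names and the `/tmp` path are illustrative.

```scala
// Sketch only: `model` is the fitted PipelineModel and `dataset` the input data
// from the README example below; stage 2 of the pipeline is the IForest stage.
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.iforest.IForestModel

model.write.overwrite().save("/tmp/iforest.model")
val reloaded = PipelineModel.load("/tmp/iforest.model")

// The training summary is not persisted, so the reloaded stage reports no summary.
val forestStage = reloaded.stages(2).asInstanceOf[IForestModel]
assert(!forestStage.hasSummary)

// Both models should flag the same number of anomalies on the same data.
val flaggedBefore = model.transform(dataset).filter("prediction = 1.0").count()
val flaggedAfter = reloaded.transform(dataset).filter("prediction = 1.0").count()
assert(flaggedBefore == flaggedAfter)
```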
@@ -92,6 +92,15 @@ val pipeline = new Pipeline().setStages(Array(indexer, assembler, iForest))
val model = pipeline.fit(dataset)
val predictions = model.transform(dataset)
// Save pipeline model
model.write.overwrite().save("/tmp/iforest.model")
// Load pipeline model
val loadedPipelineModel = PipelineModel.load("/tmp/iforest.model")
// Get loaded iforest model
val loadedIforestModel = loadedPipelineModel.stages(2).asInstanceOf[IForestModel]
println(s"The loaded iforest model has no summary: model.hasSummary = ${loadedIforestModel.hasSummary}")
val binaryMetrics = new BinaryClassificationMetrics(
  predictions.select("prediction", "label").rdd.map {
    case Row(prediction: Double, label: Double) => (prediction, label)
@@ -105,6 +114,10 @@ println(s"The model's auc: ${binaryMetrics.areaUnderROC()}")
*Python API*
```python
from pyspark.ml.linalg import Vectors
import tempfile
spark = SparkSession \
    .builder.master("local[*]") \
    .appName("IForestExample") \
......
package org.apache.spark.examples.ml
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.iforest.IForest
import org.apache.spark.ml.iforest.{IForest, IForestModel}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{Row, SparkSession}
@@ -48,6 +48,15 @@ object IForestExample {
val model = pipeline.fit(dataset)
val predictions = model.transform(dataset)
// Save pipeline model
model.write.overwrite().save("/tmp/iforest.model")
// Load pipeline model
val loadedPipelineModel = PipelineModel.load("/tmp/iforest.model")
// Get loaded iforest model
val loadedIforestModel = loadedPipelineModel.stages(2).asInstanceOf[IForestModel]
println(s"The loaded iforest model has no summary: model.hasSummary = ${loadedIforestModel.hasSummary}")
val binaryMetrics = new BinaryClassificationMetrics(
  predictions.select("prediction", "label").rdd.map {
    case Row(prediction: Double, label: Double) => (prediction, label)
......
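
Since this commit also gives IForestModel its own read/write support (that is what the test-suite change further below exercises), the forest stage can be persisted on its own, outside the pipeline. A hedged sketch, assuming the model's companion object exposes the standard MLReadable `load`, as required by the `testEstimatorAndModelReadWrite` helper; the path and variable names are illustrative.

```scala
// Sketch only: persist just the IForest stage extracted from the loaded pipeline.
import org.apache.spark.ml.iforest.IForestModel

val forestOnly = loadedPipelineModel.stages(2).asInstanceOf[IForestModel]
forestOnly.write.overwrite().save("/tmp/iforest-stage-only.model")

// Assumes IForestModel has a companion MLReadable with `load`, which the
// read/write test added in this commit relies on via testEstimatorAndModelReadWrite.
val reloadedForest = IForestModel.load("/tmp/iforest-stage-only.model")
println(s"Number of trees after reload: ${reloadedForest.trees.length}")
```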
@@ -752,8 +752,8 @@ trait IForestParams extends Params {
* Relative Error for Approximate Quantile (0 <= value <= 1), default is 0.
* @group param
*/
final val approxQuantileRelativeError: Param[Double] =
  new Param[Double](parent = this, name = "approxQuantileRelativeError", doc = "relative error for approximate quantile")
final val approxQuantileRelativeError: DoubleParam =
  new DoubleParam(parent = this, name = "approxQuantileRelativeError", doc = "relative error for approximate quantile")
/** @group setParam */
setDefault(approxQuantileRelativeError, value = 0d)
......
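
The switch above from `Param[Double]` to `DoubleParam` is the likely reason saving now works: Spark persists each param value through `jsonEncode`, and the generic `Param[T]` implementation only supports strings and vectors, whereas `DoubleParam` overrides it for doubles. A small sketch illustrating the difference; the uid string and values are made up.

```scala
// Sketch only: compare JSON encoding of a generic Param[Double] and a DoubleParam.
import org.apache.spark.ml.param.{DoubleParam, Param}

object ParamEncodingSketch {
  def main(args: Array[String]): Unit = {
    val typed = new DoubleParam("someUid", "approxQuantileRelativeError", "relative error")
    println(typed.jsonEncode(0.05)) // prints 0.05, so the value round-trips through JSON

    val generic = new Param[Double]("someUid", "approxQuantileRelativeError", "relative error")
    // generic.jsonEncode(0.05) throws NotImplementedError ("only supports string and vector"),
    // which is presumably why persisting a model whose param was declared this way failed.
  }
}
```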
@@ -159,45 +159,45 @@ class IForestSuite extends SparkFunSuite with MLlibTestSparkContext with Default
assert(model1.summary.anomalyScoreCol === model2.summary.anomalyScoreCol)
}
// Uncomment the code below to run the test for model read/write
// test("read/write") {
// def checkTreeNodes(node: IFNode, node2: IFNode): Unit = {
// (node, node2) match {
// case (node: IFInternalNode, node2: IFInternalNode) =>
// assert(node.featureValue === node2.featureValue)
// assert(node.featureIndex === node2.featureIndex)
// checkTreeNodes(node.leftChild, node2.leftChild)
// checkTreeNodes(node.rightChild, node2.rightChild)
// case (node: IFLeafNode, node2: IFLeafNode) =>
// assert(node.numInstance === node2.numInstance)
// case _ =>
// throw new AssertionError("Found mismatched nodes")
// }
// }
// def checkModelData(model: IForestModel, model2: IForestModel): Unit = {
// val trees = model.trees
// val trees2 = model2.trees
// assert(trees.length === trees2.length)
// try {
// trees.zip(trees2).foreach { case (node, node2) =>
// checkTreeNodes(node, node2)
// }
// } catch {
// case ex: Exception => throw new AssertionError(
// "checkModelData failed since the two trees were not identical.\n"
// )
// }
// }
//
// val iforest = new IForest()
// testEstimatorAndModelReadWrite(
// iforest,
// dataset,
// IForestSuite.allParamSettings,
// IForestSuite.allParamSettings,
// checkModelData
// )
// }
// test for model read/write
test("read/write") {
def checkTreeNodes(node: IFNode, node2: IFNode): Unit = {
(node, node2) match {
case (node: IFInternalNode, node2: IFInternalNode) =>
assert(node.featureValue === node2.featureValue)
assert(node.featureIndex === node2.featureIndex)
checkTreeNodes(node.leftChild, node2.leftChild)
checkTreeNodes(node.rightChild, node2.rightChild)
case (node: IFLeafNode, node2: IFLeafNode) =>
assert(node.numInstance === node2.numInstance)
case _ =>
throw new AssertionError("Found mismatched nodes")
}
}
def checkModelData(model: IForestModel, model2: IForestModel): Unit = {
val trees = model.trees
val trees2 = model2.trees
assert(trees.length === trees2.length)
try {
trees.zip(trees2).foreach { case (node, node2) =>
checkTreeNodes(node, node2)
}
} catch {
case ex: Exception => throw new AssertionError(
"checkModelData failed since the two trees were not identical.\n"
)
}
}
val iforest = new IForest()
testEstimatorAndModelReadWrite(
iforest,
dataset,
IForestSuite.allParamSettings,
IForestSuite.allParamSettings,
checkModelData
)
}
test("boundary case") {
intercept[IllegalArgumentException] {
......