summaryrefslogtreecommitdiffstats
path: root/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala')
-rw-r--r--sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala66
1 files changed, 66 insertions, 0 deletions
diff --git a/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala b/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala
new file mode 100644
index 00000000000..c660b45630a
--- /dev/null
+++ b/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala
@@ -0,0 +1,66 @@
+package com.yahoo.example.blog
+
+import org.apache.spark.ml.recommendation.ALSModel
+import org.apache.spark.sql.SparkSession
+import org.scalatest.Matchers._
+import org.scalatest._
+
+class CollaborativeFilteringTest extends FunSuite with BeforeAndAfter {
+
+ var ss: SparkSession = _
+
+ before {
+
+ ss = SparkSession
+ .builder()
+ .appName("Unit Test")
+ .master("local[*]")
+ .getOrCreate()
+
+ }
+
+ after {
+ ss.stop()
+ }
+
+ test("run method returns a MatrixFactorizationModel with latent factors of size 10 to user and item") {
+
+ val file_path = getClass.getResource("/trainingSetIndicesSample.txt")
+
+ val cf = new CollaborativeFiltering(ss)
+
+ val model = cf.run(
+ input_path = file_path.toString,
+ rank = 10,
+ numIterations = 10,
+ lambda = 0.01)
+
+ model shouldBe a [ALSModel]
+
+ val product_feature_array = model.itemFactors.first().getSeq(1)
+ assertResult(10){product_feature_array.length}
+
+ val user_feature_array = model.userFactors.first().getSeq(1)
+ assertResult(10){user_feature_array.length}
+
+ }
+
+ test("run_pipeline method returns a MatrixFactorizationModel with latent factors of size 10 to user and item") {
+
+ val file_path = getClass.getResource("/trainingSetIndicesSample.txt")
+
+ val cf = new CollaborativeFiltering(ss)
+
+ val model = cf.run_pipeline(input_path = file_path.toString, numIterations = 10)
+
+ model shouldBe a [ALSModel]
+
+ val product_feature_array = model.itemFactors.first().getSeq(1)
+ assertResult(10){product_feature_array.length}
+
+ val user_feature_array = model.userFactors.first().getSeq(1)
+ assertResult(10){user_feature_array.length}
+
+ }
+
+}