diff options
Diffstat (limited to 'sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala')
-rw-r--r-- | sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala b/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala new file mode 100644 index 00000000000..c660b45630a --- /dev/null +++ b/sample-apps/blog-tutorial-shared/src/test/scala/com/yahoo/example/blog/CollaborativeFilteringTest.scala @@ -0,0 +1,66 @@ +package com.yahoo.example.blog + +import org.apache.spark.ml.recommendation.ALSModel +import org.apache.spark.sql.SparkSession +import org.scalatest.Matchers._ +import org.scalatest._ + +class CollaborativeFilteringTest extends FunSuite with BeforeAndAfter { + + var ss: SparkSession = _ + + before { + + ss = SparkSession + .builder() + .appName("Unit Test") + .master("local[*]") + .getOrCreate() + + } + + after { + ss.stop() + } + + test("run method returns a MatrixFactorizationModel with latent factors of size 10 to user and item") { + + val file_path = getClass.getResource("/trainingSetIndicesSample.txt") + + val cf = new CollaborativeFiltering(ss) + + val model = cf.run( + input_path = file_path.toString, + rank = 10, + numIterations = 10, + lambda = 0.01) + + model shouldBe a [ALSModel] + + val product_feature_array = model.itemFactors.first().getSeq(1) + assertResult(10){product_feature_array.length} + + val user_feature_array = model.userFactors.first().getSeq(1) + assertResult(10){user_feature_array.length} + + } + + test("run_pipeline method returns a MatrixFactorizationModel with latent factors of size 10 to user and item") { + + val file_path = getClass.getResource("/trainingSetIndicesSample.txt") + + val cf = new CollaborativeFiltering(ss) + + val model = cf.run_pipeline(input_path = file_path.toString, numIterations = 10) + + model shouldBe a [ALSModel] + + val product_feature_array = model.itemFactors.first().getSeq(1) + assertResult(10){product_feature_array.length} + + val user_feature_array = model.userFactors.first().getSeq(1) + assertResult(10){user_feature_array.length} + + } + +} |