summaryrefslogtreecommitdiffstats
path: root/sample-apps/blog-tutorial-shared/src/main/scala/com/yahoo/example/blog/BlogRecommendationApp.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sample-apps/blog-tutorial-shared/src/main/scala/com/yahoo/example/blog/BlogRecommendationApp.scala')
-rw-r--r--sample-apps/blog-tutorial-shared/src/main/scala/com/yahoo/example/blog/BlogRecommendationApp.scala202
1 files changed, 202 insertions, 0 deletions
diff --git a/sample-apps/blog-tutorial-shared/src/main/scala/com/yahoo/example/blog/BlogRecommendationApp.scala b/sample-apps/blog-tutorial-shared/src/main/scala/com/yahoo/example/blog/BlogRecommendationApp.scala
new file mode 100644
index 00000000000..83d56718823
--- /dev/null
+++ b/sample-apps/blog-tutorial-shared/src/main/scala/com/yahoo/example/blog/BlogRecommendationApp.scala
@@ -0,0 +1,202 @@
+package com.yahoo.example.blog
+
+import org.apache.spark.sql.SparkSession
+
+object BlogRecommendationApp {
+ val usage = """
+ Usage: spark-submit \
+ | --class "BlogRecommendationApp" \
+ | --master local[4] \
+ | JAR_FILE
+ | --task task_command [TASK RELATED OPTIONS]
+
+ spark-submit \
+ | --class "BlogRecommendationApp" \
+ | --master local[4] \
+ | JAR_FILE
+ | --task collaborative_filtering
+ | --input_file path
+ | --rank value
+ | --numIterations value
+ | --lambda value
+ | --output_path path
+
+ spark-submit \
+ | --class "BlogRecommendationApp" \
+ | --master local[4] \
+ | JAR_FILE
+ | --task collaborative_filtering_cv
+ | --input_file path
+ | --numIterations value
+ | --output_path path
+ |
+
+ spark-submit \
+ | --class "BlogRecommendationApp" \
+ | --master local[4] \
+ | JAR_FILE
+ | --task split_set
+ | --input_file path
+ | --test_perc_stage1 value
+ | --test_perc_stage2 value
+ | --seed value
+ | --output_path path
+ """
+
+ private val COLLABORATIVE_FILTERING = "collaborative_filtering"
+ private val COLLABORATIVE_FILTERING_CV = "collaborative_filtering_cv"
+ private val SPLIT_SET_INTO_TRAIN_AND_TEST = "split_set"
+
+ type OptionMap = Map[Symbol, Any]
+
+ def main(args: Array[String]) {
+
+ val options = parseCommandLineOptions(args)
+ val task_name = options('task).toString
+
+ task_name match {
+ case COLLABORATIVE_FILTERING => CollaborativeFilteringExample(options)
+ case COLLABORATIVE_FILTERING_CV => CollaborativeFilteringCV(options)
+ case SPLIT_SET_INTO_TRAIN_AND_TEST => SplitSetIntoTrainingAndTestSets(options)
+ }
+
+ }
+
+ private def SplitSetIntoTrainingAndTestSets(options: OptionMap) = {
+
+ val spark = SparkSession
+ .builder()
+ .appName("Split Full Data Into Train and Test Sets")
+ .getOrCreate()
+
+ val splitter = new SplitFullSetIntoTrainAndTestSets(spark)
+
+ val sets = splitter.run(input_file_path = options('input_file).toString,
+ test_perc_stage1 = options('test_perc_stage1).toString.toDouble,
+ test_perc_stage2 = options('test_perc_stage2).toString.toDouble,
+ seed = options('seed).toString.toInt)
+
+ SplitFullSetIntoTrainAndTestSets.writeTrainAndTestSetsIndices(sets, options('output_path).toString)
+
+ }
+
+ private def CollaborativeFilteringExample(options: OptionMap) = {
+
+ // TODO: Check if output_path already exist
+
+ val spark = SparkSession
+ .builder()
+ .appName("Collaborative Filtering")
+ .getOrCreate()
+
+ val cf = new CollaborativeFiltering(spark)
+
+ val model = cf.run(
+ input_path = options('input_file).toString,
+ rank = options('rank).toString.toInt,
+ numIterations = options('num_iterations).toString.toInt,
+ lambda = options('lambda).toString.toDouble)
+
+ CollaborativeFiltering.writeFeaturesAsVespaTensorText(model, options('output_path).toString)
+
+ }
+
+ private def CollaborativeFilteringCV(options: OptionMap) = {
+
+ // TODO: Check if output_path already exist
+
+ val spark = SparkSession
+ .builder()
+ .appName("Collaborative Filtering CV")
+ .getOrCreate()
+
+ val cf = new CollaborativeFiltering(spark)
+
+ val model = cf.run_pipeline(
+ input_path = options('input_file).toString,
+ numIterations = options('num_iterations).toString.toInt)
+
+ CollaborativeFiltering.writeFeaturesAsVespaTensorText(model, options('output_path).toString)
+
+ }
+
+ private def parseCommandLineOptions(args: Array[String]): OptionMap = {
+
+ def findTask(list: List[String]) : String = {
+ list match {
+ case Nil => println("Please, define a valid task" + "\n" + usage)
+ sys.exit(1)
+ case "--task" :: value :: tail =>
+ value
+ case option :: tail => findTask(tail)
+ }
+ }
+
+ def ParseCollaborativeFilteringOptions(map : OptionMap, list: List[String]) : OptionMap = {
+ list match {
+ case Nil => map
+ case "--input_file" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('input_file -> value.toString), tail)
+ case "--rank" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('rank -> value.toInt), tail)
+ case "--numIterations" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('num_iterations -> value.toInt), tail)
+ case "--lambda" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('lambda -> value.toDouble), tail)
+ case "--output_path" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('output_path -> value.toString), tail)
+ case option :: tail =>
+ ParseCollaborativeFilteringOptions(map, tail)
+ }
+ }
+
+ def ParseCollaborativeFilteringCVOptions(map : OptionMap, list: List[String]) : OptionMap = {
+ list match {
+ case Nil => map
+ case "--input_file" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('input_file -> value.toString), tail)
+ case "--numIterations" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('num_iterations -> value.toInt), tail)
+ case "--output_path" :: value :: tail =>
+ ParseCollaborativeFilteringOptions(map ++ Map('output_path -> value.toString), tail)
+ case option :: tail =>
+ ParseCollaborativeFilteringOptions(map, tail)
+ }
+ }
+
+ def ParseSplitSetOptions(map : OptionMap, list: List[String]) : OptionMap = {
+ list match {
+ case Nil => map
+ case "--input_file" :: value :: tail =>
+ ParseSplitSetOptions(map ++ Map('input_file -> value.toString), tail)
+ case "--test_perc_stage1" :: value :: tail =>
+ ParseSplitSetOptions(map ++ Map('test_perc_stage1 -> value.toDouble), tail)
+ case "--test_perc_stage2" :: value :: tail =>
+ ParseSplitSetOptions(map ++ Map('test_perc_stage2 -> value.toDouble), tail)
+ case "--seed" :: value :: tail =>
+ ParseSplitSetOptions(map ++ Map('seed -> value.toInt), tail)
+ case "--output_path" :: value :: tail =>
+ ParseSplitSetOptions(map ++ Map('output_path -> value.toString), tail)
+ case option :: tail =>
+ ParseSplitSetOptions(map , tail)
+ }
+ }
+
+ if (args.length == 0) println(usage)
+ val arglist = args.toList
+
+ val task_name = findTask(arglist)
+
+ val options = task_name match {
+ case COLLABORATIVE_FILTERING => ParseCollaborativeFilteringOptions(Map('task -> task_name), arglist)
+ case COLLABORATIVE_FILTERING_CV => ParseCollaborativeFilteringCVOptions(Map('task -> task_name), arglist)
+ case SPLIT_SET_INTO_TRAIN_AND_TEST => ParseSplitSetOptions(Map('task -> task_name), arglist)
+ }
+
+ options
+
+ }
+
+}
+
+