diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2017-06-20 10:27:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-06-20 10:27:18 +0200 |
commit | 552782a9b7da4e85721c05a62b9b8c3fed5c6a15 (patch) | |
tree | fff41f418d74e507769c20951074a4bc1e5c34bc /searchlib/src | |
parent | 60162a9ca2efd83c95b5b06e2944e75c379c5223 (diff) | |
parent | 541a1ad6a844af5a2a283708fb5a2a87febafb07 (diff) |
Merge pull request #2826 from yahoo/bratseth/remove-ga
Remove code not necessary for Vespa
Diffstat (limited to 'searchlib/src')
28 files changed, 0 insertions, 2193 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/CaseList.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/CaseList.java deleted file mode 100644 index 13343029ebc..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/CaseList.java +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import java.util.List; - -/** - * A producer of a list of cases for function training. - * - * @author bratseth - */ -public interface CaseList { - - public List<TrainingSet.Case> cases(); - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Evolvable.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Evolvable.java deleted file mode 100644 index bbd3844d036..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Evolvable.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; - -import java.util.List; - -/** - * An entity which may evolve over time - * - * @author bratseth - */ -public abstract class Evolvable implements Comparable<Evolvable> { - - public abstract Evolvable makeSuccessor(int memberNumber, List<RankingExpression> genepool, TrainingEnvironment environment); - - public abstract RankingExpression getGenepool(); - - @Override - public int compareTo(Evolvable other) { - return -Double.compare(getFitness(), other.getFitness()); - } - - public abstract double getFitness(); - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Individual.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Individual.java deleted file mode 100644 index e42636c00b2..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Individual.java +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; - -import java.util.Collections; -import java.util.List; - -/** - * An individual in an evolving population - a genome with a fitness score. - * Individuals are comparable by decreasing fitness. - * <p> - * As we are training ranking expressions, the genome, here, is the ranking expression. - * - * @author bratseth - */ -public class Individual extends Evolvable { - - private final RankingExpression genome; - private final TrainingSet trainingSet; - private final double fitness; - - public Individual(RankingExpression genome, TrainingSet trainingSet) { - this.genome = genome; - this.trainingSet = trainingSet; - this.fitness = trainingSet.evaluate(genome); - } - - public RankingExpression getGenome() { return genome; } - - public double calculateAverageError() { - return trainingSet.calculateAverageError(genome); - } - - public double calculateAverageErrorPercentage() { - return trainingSet.calculateAverageErrorPercentage(genome); - } - - @Override - public double getFitness() { return fitness; } - - @Override - public Individual makeSuccessor(int memberNumber, List<RankingExpression> genepool, TrainingEnvironment environment) { - return new Individual(environment.recombiner().recombine(genome, genepool), trainingSet); - } - - @Override - public RankingExpression getGenepool() { - return genome; - } - - @Override - public String toString() { - return toSomewhatShortString() + ", expression: " + genome; - } - - /** Returns a shorter string describing this (not including the expression */ - public String toSomewhatShortString() { - return "Error % " + calculateAverageErrorPercentage() + - " average error " + calculateAverageError() + - " fitness " + getFitness(); - } - - /** Returns a shorter string describing this (not including the expression */ - public String toShortString() { - return "Error: " + calculateAverageErrorPercentage() + " %"; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/KeyboardChecker.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/KeyboardChecker.java deleted file mode 100644 index 58e569bef33..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/KeyboardChecker.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import java.awt.KeyEventDispatcher; -import java.awt.KeyboardFocusManager; -import java.awt.event.KeyEvent; - -/** - * TODO - * - * @author bratseth - */ -public class KeyboardChecker { - - private static boolean qPressed = false; - - private final Object lock = new Object(); - - public KeyboardChecker() { - KeyboardFocusManager.getCurrentKeyboardFocusManager().addKeyEventDispatcher(new KeyEventDispatcher() { - - @Override - public boolean dispatchKeyEvent(KeyEvent ke) { - synchronized (lock) { - switch (ke.getID()) { - case KeyEvent.KEY_PRESSED: - if (ke.getKeyCode() == KeyEvent.VK_Q) { - qPressed = true; - } - break; - - case KeyEvent.KEY_RELEASED: - if (ke.getKeyCode() == KeyEvent.VK_Q) { - qPressed = false; - } - break; - } - return false; - } - } - }); - } - - public boolean isQPressed() { - synchronized (lock) { - return qPressed; - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Main.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Main.java deleted file mode 100644 index 204c03b92b6..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Main.java +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.io.IOUtils; -import com.yahoo.searchlib.mlr.ga.caselist.FileCaseList; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; - -import java.io.BufferedReader; -import java.io.FileReader; -import java.io.IOException; - -/** - * Command line runner for training sessions - * - * @author bratseth - */ -/* -TODO: Switch order of generation and sequence in names -TODO: Output fitness improvement on each step (esp useful for species evolution) -TODO: Detect local optima (no improvement for n rounds) and stop early -TODO: Split into training and validation sets - */ -public class Main { - - public Main(String[] args, Tracker tracker) { - if (args.length < 1 || args[0].trim().equals("help")) { - System.out.println( - "Finds a ranking expression matching a training set given as a case file.\n" + - "Run until the expression seems good enough.\n" + - "Usage: ga <case-file> - \n" + - " where case-file is a file containing case lines on the form \n" + - " targetValue, argument1:value1, ...\n" + - " (comment lines starting by # are also permitted)\n"); - return; - } - - TrainingParameters parameters = new TrainingParameters(); - //parameters.setAllowConditions(false); - parameters.setErrorIsRelative(false); - parameters.setInitialSpeciesSize(40); - parameters.setSpeciesLifespan(100); - parameters.setExcludeFeatures("F7,F9,F10,F11,F12,F13,F14,F15,F16,F17,F18,F19,F21,F23,F24,F25,F26,F27,F29,F30,F32,F33,F34,F35,F36,F37,F38,F39,F40,F41,F42,F44,F46,F47,F48,F49,F50,F52,F53,F55,F56,F57,F58,F59,F60,F61,F62,F63,F64,F65,F67,F69,F70,F71,F72,F73,F75,F76,F78,F79,F80,F81,F82,F83,F84,F85,F86,F87,F88,F90,F92,F93,F94,F95,F96,F98,F99,F100,F101,F102,F103,F104,F105,F106,F107,F108,F109,F66,F89,F110"); - //parameters.setInitialSpeciesSize(20); - - String caseFile = args[0]; - TrainingSet trainingSet = new TrainingSet(FileCaseList.create(caseFile, parameters), parameters); - Trainer trainer = new Trainer(trainingSet); - - if (args.length > 1) { // Evaluate given expression - try { - Individual given = new Individual(new RankingExpression(new BufferedReader(new FileReader(args[1]))), trainingSet); - System.out.println("Error in '" + args[1] + "': error % " + given.calculateAverageErrorPercentage() + - " average error " + given.calculateAverageError() + - " fitness " + given.getFitness()); - } - catch (IOException | ParseException e) { - throw new IllegalArgumentException("Could not evaluate expression in argument 2", e); - } - } - else { // Train expression - // TODO: Move system outs to tracker - System.out.println("Learning ..."); - RankingExpression learntExpression = trainer.train(parameters, tracker); - System.out.println("Learnt expression: " + learntExpression); - } - } - - public static void main(String[] args) { - new Main(args, new PrintingTracker(10, 0, 1)); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Population.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Population.java deleted file mode 100644 index 8aa47db6d09..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Population.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** - * A collection of evolvables - * - * @author bratseth - */ -public class Population { - - /** The current members of this population, always sorted by decreasing fitness */ - private List<Evolvable> members; - - public Population(List<Evolvable> initialMembers) { - members = new ArrayList<>(initialMembers); - Collections.sort(members); - } - - /** Returns the most fit member of this population (never null) */ - public Evolvable best() { - return members.get(0); - } - - /** Returns the members of this population as an unmodifiable list sorted by decreasing fitness*/ - public List<Evolvable> members() { return Collections.unmodifiableList(members); } - - public void evolve(int generation, TrainingEnvironment environment) { - TrainingParameters p = environment.parameters(); - int generationSize = p.getInitialSpeciesSize() - - (int)Math.round((p.getInitialSpeciesSize() - p.getFinalSpeciesSize()) * generation/p.getSpeciesLifespan()); - members = breed(members, generationSize * p.getGenerationCandidatesFactor(), environment); - Collections.sort(members); - members = members.subList(0, Math.min(generationSize, members.size())); - } - - private List<Evolvable> breed(List<Evolvable> members, int offspringCount, TrainingEnvironment environment) { - List<Evolvable> offspring = new ArrayList<>(offspringCount); // TODO: Can we do this inline and keep the list forever (and then also the immutable view) - offspring.add(members.get(0)); // keep the best as-is - List<RankingExpression> genePool = collectGenepool(members); - for (int i = 0; i < offspringCount - 1; i++) { - Evolvable child = members.get(i % members.size()).makeSuccessor(i, genePool, environment); - offspring.add(child); - } - return offspring; - } - - private List<RankingExpression> collectGenepool(List<Evolvable> members) { - List<RankingExpression> genepool = new ArrayList<>(); - for (Evolvable member : members) - genepool.add(member.getGenepool()); - return genepool; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/PrintingTracker.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/PrintingTracker.java deleted file mode 100644 index 1bd7980bc3f..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/PrintingTracker.java +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.yolean.Exceptions; - -import java.util.List; - -/** - * A tracker which prints a summary of training events to standard out - * - * @author bratseth - */ -public class PrintingTracker implements Tracker { - - private final int iterationEvery; - private final int survivorsEvery; - private final int printSpeciesCreationLevel; - private final int printSpeciesCompletionLevel; - - public PrintingTracker() { - this(0, 1); - } - - public PrintingTracker(int printSpeciesCreationLevel, int printSpeciesCompletionLevel) { - this(Integer.MAX_VALUE, Integer.MAX_VALUE, printSpeciesCreationLevel, printSpeciesCompletionLevel); - } - - public PrintingTracker(int iterationEvery, int printSpeciesCreationLevel, int printSpeciesCompletionLevel) { - this(iterationEvery, Integer.MAX_VALUE, printSpeciesCreationLevel, printSpeciesCompletionLevel); - } - - public PrintingTracker(int iterationEvery, int survivorsEvery, int printSpeciesCreationLevel, int printSpeciesCompletionLevel) { - this.iterationEvery = iterationEvery; - this.survivorsEvery = survivorsEvery; - this.printSpeciesCreationLevel = printSpeciesCreationLevel; - this.printSpeciesCompletionLevel = printSpeciesCompletionLevel; - } - - @Override - public void newSpecies(Species predecessor, int initialSize, List<RankingExpression> genePool) { - if (predecessor.name().level() > printSpeciesCreationLevel) return; - System.out.println(spaces(predecessor.name().level()*2) + "Creating new species of size " + initialSize + " and a gene pool of size " + genePool.size() + " from predecessor " + predecessor); - } - - @Override - public void newSpeciesCreated(Species species) { - if (species.name().level() > printSpeciesCreationLevel) return; - System.out.println(spaces(species.name().level()*2) + "Created and will now evolve " + species); - } - - @Override - public void speciesCompleted(Species species) { - if (species.name().level() > printSpeciesCompletionLevel) return; - System.out.println(spaces(species.name().level()*2) + "--> Evolution completed for " + species); - } - - /** Called each time a species (or super-species) have completed one generation */ - @Override - public void iteration(Species species, int generation) { - try { - new RankingExpression(species.bestIndividual().getGenome().toString()); - } - catch (Exception e) { - System.err.println("ERROR: " + Exceptions.toMessageString(e) + ": " + species.bestIndividual().getGenome()); - } - - if ( (generation % iterationEvery) == 0) - System.out.println(spaces(species.name().level()*2) + "Gen " + generation + " of " + species); - - if ( (generation % survivorsEvery) == 0) - printPopulation(species.name().level(), species.population().members()); - } - - @Override - public void result(Evolvable winner) { - System.out.println("Learnt expression: " + winner); - } - - private String spaces(int spaces) { - return " ".substring(0,spaces); - } - - private void printPopulation(int level, List<Evolvable> survivors) { - if (survivors.size()<=1) return; - System.out.println(" Population:"); - for (Evolvable individual : survivors) - System.out.println(spaces(level*2) + " " + individual); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/RankingExpressionCaseList.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/RankingExpressionCaseList.java deleted file mode 100644 index 596db4cfd42..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/RankingExpressionCaseList.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.mlr.ga.CaseList; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.mlr.ga.TrainingSet; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** - * Produces a list of training cases (argument and target value pairs) - * from a Ranking Expression. - * Useful for testing. - * - * @author bratseth - */ -public class RankingExpressionCaseList implements CaseList { - - private final List<TrainingSet.Case> cases = new ArrayList<TrainingSet.Case>(); - - public RankingExpressionCaseList(List<Context> arguments, RankingExpression targetFunction) { - for (Context argument : arguments) - cases.add(new TrainingSet.Case(argument,targetFunction.evaluate(argument).asDouble())); - } - - /** Returns the list of cases generated from the ranking expression */ - @Override - public List<TrainingSet.Case> cases() { return Collections.unmodifiableList(cases); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Recombiner.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Recombiner.java deleted file mode 100644 index 8fd40ec793f..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Recombiner.java +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.*; - -import java.util.*; -import java.util.logging.Logger; - -import static java.lang.Math.abs; -import static java.lang.Math.max; -import static java.lang.Math.min; - -/** - * A class which returns a mutated, recombined genome from a list of parent genomes. - * - * @author bratseth - */ -public class Recombiner { - - // TODO: Either make ranking expressions immutable and get rid of parent pointer, or do clone everywhere below - - private static final Logger log = Logger.getLogger(Trainer.class.getName()); - - private final Random random = new Random(); - - private final List<String> features; - - private final TrainingParameters parameters; - - /** - * Creates a recombiner - * - * @param features the list of feature names which are possible within the space we are training, - * such that these may be spontaneously added to expressions. - */ - public Recombiner(Collection<String> features, TrainingParameters trainingParameters) { - this.features = Collections.unmodifiableList(new ArrayList<>(features)); - this.parameters = trainingParameters; - } - - public RankingExpression recombine(RankingExpression genome, List<RankingExpression> genePool) { - List<ExpressionNode> genePoolRoots = new ArrayList<>(); - for (RankingExpression genePoolGenome : genePool) - genePoolRoots.add(genePoolGenome.getRoot()); - return new RankingExpression(mutate(genome.getRoot(), genePoolRoots, 0)); - } - - private ExpressionNode mutate(ExpressionNode gene, List<ExpressionNode> genePool, int depth) { - // TODO: Extract insert level - if (gene instanceof BooleanNode) - return simplifyCondition(mutateChildren((CompositeNode)gene,genePool,depth+1)); - if (gene instanceof CompositeNode) - return insertNodeLevel(simplify(removeNodeLevel(mutateChildren((CompositeNode)gene,genePool,depth+1))), genePool, depth+1); - else - return insertNodeLevel(mutateLeaf(gene), genePool, depth+1); - } - - private BooleanNode simplifyCondition(ExpressionNode node) { - // Nothing yet - return (BooleanNode)node; - } - - /** Very basic algorithmic simplification */ - private ExpressionNode simplify(ExpressionNode node) { - if (! (node instanceof CompositeNode)) return node; - CompositeNode composite = (CompositeNode)node; - if (maxDepth(composite)>2) return composite; - List<ExpressionNode> children = composite.children(); - if (children.size()!=2) return composite; - if ( ! (children.get(0) instanceof ConstantNode)) return composite; - if ( ! (children.get(1) instanceof ConstantNode)) return composite; - return new ConstantNode(composite.evaluate(null)); - } - - private CompositeNode mutateChildren(CompositeNode gene, List<ExpressionNode> genePool, int depth) { - if (gene instanceof ReferenceNode) return gene; // TODO: Remove if we make this a non-composite - - List<ExpressionNode> mutatedChildren = new ArrayList<>(); - for (ExpressionNode child : gene.children()) - mutatedChildren.add(mutate(child, genePool, depth)); - return gene.setChildren(mutatedChildren); - } - - private ExpressionNode insertNodeLevel(ExpressionNode gene, List<ExpressionNode> genePool, int depth) { - if (probability() < 0.9) return gene; - if (depth + maxDepth(gene) >= parameters.getMaxExpressionDepth()) return gene; - ExpressionNode newChild = generateChild(genePool, depth); - if (probability() < 0.5) - return generateComposite(gene, newChild, genePool, depth); - else - return generateComposite(newChild, gene, genePool, depth); - } - - private ExpressionNode removeNodeLevel(CompositeNode gene) { - if (gene instanceof ReferenceNode) return gene; // TODO: Remove if we make featurenode a non-composite - if (probability() < 0.9) return gene; - return randomFrom(gene.children()); - } - - private ExpressionNode generateComposite(ExpressionNode left, ExpressionNode right, List<ExpressionNode> genePool, int depth) { - int type = random.nextInt(2 + ( parameters.getAllowConditions() ? 1:0 ) ); // pick equally between 2 or 3 types - if (type == 0) { - return new ArithmeticNode(left, pickArithmeticOperator(), right); - } - else if (type == 1) { - Function function = pickFunction(); - if (function.arity() == 1) - return new FunctionNode(function, left); - else // arity==2 - return new FunctionNode(function, left, right); - } - else { - return new IfNode(generateCondition(genePool, depth + 1), left, right); - } - } - - private BooleanNode generateCondition(List<ExpressionNode> genePool, int depth) { - // TODO: Add set membership nodes - return new ComparisonNode(generateChild(genePool, depth), TruthOperator.SMALLER, generateChild(genePool, depth)); - } - - private ExpressionNode generateChild(List<ExpressionNode> genePool, int depth) { - if (genePool.isEmpty() || probability() < 0.1) { // entirely new child - return generateLeaf(); - } - else { // pick from gene pool - ExpressionNode picked = randomFrom(genePool); - int pickedDepth = 0; - // descend until we are at at least the same depth as this depth - // to make sure branches spliced in are shallow enough that we avoid growing - // larger than maxDepth - while (picked instanceof CompositeNode && (pickedDepth++ < depth || probability() < 0.5)) { - if (picked instanceof ReferenceNode) continue; // TODO: Remove if we make referencenode a noncomposite - picked = randomFrom(((CompositeNode)picked).children()); - } - return picked; - } - } - - public ExpressionNode mutateLeaf(ExpressionNode leaf) { - if (probability() < 0.5) return leaf; // TODO: For performance. Drop? - // TODO: Other leaves - ConstantNode constant = (ConstantNode)leaf; - return new ConstantNode(DoubleValue.frozen(constant.getValue().asDouble()*aboutOne())); - } - - public ExpressionNode generateLeaf() { - if (probability()<0.5 || features.size() == 0) - return new ConstantNode(DoubleValue.frozen(random.nextDouble() * 2000 - 1000)); // TODO: Use some non-uniform distribution - else - return new ReferenceNode(randomFrom(features)); - } - - private double aboutOne() { - return 1 + Math.pow(-0.1, random.nextInt(4) + 1); - } - - private double probability() { - return random.nextDouble(); - } - - private <T> T randomFrom(List<T> expressionList) { - return expressionList.get(random.nextInt(expressionList.size())); - } - - private ArithmeticOperator pickArithmeticOperator() { - switch (random.nextInt(4)) { - case 0: return ArithmeticOperator.PLUS; - case 1: return ArithmeticOperator.MINUS; - case 2: return ArithmeticOperator.MULTIPLY; - case 3: return ArithmeticOperator.DIVIDE; - } - throw new RuntimeException("This cannot happen"); - } - - /** Pick among the subset of functions which are probably useful */ - private Function pickFunction() { - switch (random.nextInt(5)) { - case 0: return Function.tanh; - case 1: return Function.exp; - case 2: return Function.log; - case 3: return Function.pow; - case 4: return Function.sqrt; - } - throw new RuntimeException("This cannot happen"); - } - - // TODO: Make ranking expressions immutable and compute this on creation? - private int maxDepth(ExpressionNode node) { - if ( ! (node instanceof CompositeNode)) return 1; - - int maxChildDepth = 0; - for (ExpressionNode child : ((CompositeNode)node).children()) - maxChildDepth = Math.max(maxDepth(child), maxChildDepth); - return maxChildDepth + 1; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Species.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Species.java deleted file mode 100644 index 1870f9c0afc..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Species.java +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; - -import java.util.ArrayList; -import java.util.List; - -/** - * A species is a population of evolvables. - * Contrary to a real species, a species population may contain (sub)species - * rather than individuals - at all levels but the lowest. - * - * @author bratseth - */ -public class Species extends Evolvable { - - private SpeciesName name; - private final Population population; - - /** Create a species having a given initial population */ - public Species(SpeciesName name, Population population) { - this.name = name; - this.population = population; - } - - /** Create a species evolved from a predecessor species, using the given gene pool for mutating it */ - private Species(SpeciesName name, Species predecessor, List<RankingExpression> genepool, TrainingEnvironment environment) { - this.name = name; - environment.tracker().newSpecies(predecessor, environment.parameters().getInitialSpeciesSize(), genepool); - - // Initialize new species with members generated from the predecessor species - List<Evolvable> initialMembers = new ArrayList<>(); - for (int i = 0; i < environment.parameters().getInitialSpeciesSize(); i++) - initialMembers.add(drawFrom(predecessor.population, i).makeSuccessor(i, genepool, environment)); - population = new Population(initialMembers); - - // Evolve the population of this species for the configured number of generations - environment.tracker().newSpeciesCreated(this); - for (int generation = 0; generation < environment.parameters().getSpeciesLifespan(); generation++) { - environment.tracker().iteration(this, generation+1); - population.evolve(generation, environment); - if (Double.isInfinite(bestIndividual().getFitness())) break; // jackpot - // if (keyboardChecker.isQPressed()) break; // user quit TODO: Make work - } - environment.tracker().speciesCompleted(this); - } - - /** - * Draws a member from the given population, where the probability of being drawn is proportional to the - * fitness of the member - */ - private Evolvable drawFrom(Population population, int succession) { - return population.members().get(Math.min(succession % 3, population.members().size() - 1)); // TODO: Probabilistic by fitness? - } - - public SpeciesName name() { return name; } - - /** The fitness of the fittest individual in the population */ - @Override - public double getFitness() { - return population.best().getFitness(); - } - - /** Creates the successor of this, using its genes, mutated drawing from the given gene pool */ - @Override - public Evolvable makeSuccessor(int memberNumber, List<RankingExpression> genepool, TrainingEnvironment environment) { - return new Species(name.successor(memberNumber), this, genepool, environment); - } - - /** Returns the members of this species */ - public Population population() { return population; } - - /** The genes of the fittest individual in the population of this */ - @Override - public RankingExpression getGenepool() { // TODO: Less sharp? - return population.best().getGenepool(); - } - - /** Returns the best individual below this in the species hierarchy (e.g recursively the best leaf) */ - public Individual bestIndividual() { - Evolvable child = this; - while (child instanceof Species) - child = ((Species)child).population.best(); - return (Individual)child; // it is when it is not instanceof Species - } - - @Override - public String toString() { - return "species " + name + ", best member: " + population.best(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/SpeciesName.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/SpeciesName.java deleted file mode 100644 index 862d8e8899d..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/SpeciesName.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -/** - * The name of a species. For tracking purposes. - * A name has the form superSpeciesName + "/" + serialNumber.generationNumber. - * - * @author bratseth - */ -public class SpeciesName { - - private final int level, serial, generation; - - private final String name, prefixName; - - private SpeciesName(int level, int serial, int generation, String prefixName) { - this.level = level; - this.serial = serial; - this.generation = generation; - this.prefixName = prefixName; - if (level == 0) - this.name = ""; - else - this.name = prefixName + (prefixName.isEmpty() ? "" : "/") + serial + "." + generation; - } - - /** - * The level in the species hierarchy of the species having this name. - * The root species has level 0. - */ - public int level() { return level; } - - /** Returns the name of the root species: The empty string at level 0 */ - public static SpeciesName createRoot() { - return new SpeciesName(0 ,0 ,0, ""); - } - - @Override - public String toString() { - if (level == 0) return "(root)"; - return name; - } - - /** Returns the name of a new subspecies */ - public SpeciesName subspecies(int serial) { - return new SpeciesName(level+1, serial, 0, name); - } - - /** Returns the name of the successor of this species */ - public SpeciesName successor(int serial) { - return new SpeciesName(level, serial, generation+1, prefixName); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Tracker.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Tracker.java deleted file mode 100644 index 0a18820560b..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Tracker.java +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; - -import java.util.List; - -/** - * A tracker receives callbacks about events happening during a training session. - * - * @author bratseth - */ -public interface Tracker { - - public void newSpecies(Species predecessor, int initialSize, List<RankingExpression> genePool); - - public void newSpeciesCreated(Species species); - - public void speciesCompleted(Species species); - - public void iteration(Species species, int generation); - - public void result(Evolvable winner); - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Trainer.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Trainer.java deleted file mode 100644 index b5268f1bb98..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/Trainer.java +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -/** - * Learns a ranking expression from some seed expressions and a training set. - * - * @author bratseth - */ -public class Trainer { - - // TODO: Simplify this to constructor only ... or maybe remove ... or combine with TrainingEnvironment - // TODO: Also: Rename to Training? - - private final TrainingSet trainingSet; - private final Set<String> argumentNames; - - /** - * Creates a new trainer. - */ - public Trainer(TrainingSet trainingSet) { - this(trainingSet, trainingSet.argumentNames()); - } - - /** - * Creates a new trainer which uses a specified list of expression argument names - * rather than the argument names given by the training set. - */ - public Trainer(TrainingSet trainingSet, Set<String> argumentNames) { - this.trainingSet = trainingSet; - this.argumentNames = new HashSet<>(argumentNames); - } - - public RankingExpression train(TrainingParameters parameters, Tracker tracker) { - TrainingEnvironment environment = new TrainingEnvironment(new Recombiner(argumentNames, parameters), tracker, trainingSet, parameters); - SpeciesName rootName = SpeciesName.createRoot(); - Species genesisSubSpecies = new Species(rootName.subspecies(0), new Population(Collections.<Evolvable>singletonList(new Individual(new RankingExpression(new ConstantNode(new DoubleValue(1))), trainingSet)))); - Species rootSpecies = (Species) new Species(rootName, new Population(Collections.<Evolvable>singletonList(genesisSubSpecies))) - .makeSuccessor(0, Collections.<RankingExpression>emptyList(), environment); - Individual winner = rootSpecies.bestIndividual(); - tracker.result(winner); - return winner.getGenome(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingEnvironment.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingEnvironment.java deleted file mode 100644 index e874267970c..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingEnvironment.java +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -/** - * The static environment of a training session - * - * @author bratseth - */ -public class TrainingEnvironment { - - // TODO: Not sure if this belongs ... or should even be an instance - // TODO: maybe collapse Trainer into this and call it TrainingSession - private final Recombiner recombiner; - private final Tracker tracker; - private final TrainingSet trainingSet; - private final TrainingParameters parameters; - - public TrainingEnvironment(Recombiner recombiner, Tracker tracker, - TrainingSet trainingSet, TrainingParameters parameters) { - this.recombiner = recombiner; - this.tracker = tracker; - this.trainingSet = trainingSet; - this.parameters = parameters; - } - - public Recombiner recombiner() { return recombiner; } - public Tracker tracker() { return tracker; } - public TrainingSet trainingSet() { return trainingSet; } - public TrainingParameters parameters() { return parameters; } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingParameters.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingParameters.java deleted file mode 100644 index 71ff8bfe259..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingParameters.java +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import java.util.HashSet; -import java.util.Set; - -/** - * @author bratseth - */ -public class TrainingParameters { - - // A note: - // The total number of species generated and evaluated is - // (generationCandidatesFactor * speciesLifespan * (initialSpeciesSize-finalSpeciesSize)/2 ) ^ speciesLevels - // (speciesLevel is hardcoded to 2 atm) - - private int speciesLifespan = 1000; - private int initialSpeciesSize = 10; - private double finalSpeciesSize = 1; - private int generationCandidatesFactor = 3; - private int maxExpressionDepth = 6; - private boolean allowConditions = true; - private boolean errorIsRelative = true; - private Set<String> excludeFeatures = new HashSet<>(); - private String trainingSetFormat = null; - private double validationFraction = 0.2; - - /** The number of generation which a given species (or super-species at any level) lives. Default:1000 */ - public int getSpeciesLifespan() { return speciesLifespan; } - public void setSpeciesLifespan(int generations) { this.speciesLifespan = generations; } - - /** The number of members in a species (or super-species at any level) as it is created. Default: 10 */ - public int getInitialSpeciesSize() { return initialSpeciesSize; } - public void setInitialSpeciesSize(int initialSpeciesSize) { this.initialSpeciesSize = initialSpeciesSize; } - - /** - * The number of members in a species in its final generation. - * The size of the species will be reduced linearly in each generation to go from initial size to final size. - * Default: 1 - */ - public double getFinalSpeciesSize() { return finalSpeciesSize; } - public void setFinalSpeciesSize(int finalSpeciesSize) { this.finalSpeciesSize = finalSpeciesSize; } - - /* - * The factor determining how many more members are generated than are allowed to survive in each generation of a species. - * Default: 3 - */ - public int getGenerationCandidatesFactor() { return generationCandidatesFactor; } - public void setGenerationCandidatesFactor(int generationCandidatesFactor) { this.generationCandidatesFactor = generationCandidatesFactor; } - - /** - * The max depth of expressions this is allowed to generate. - * Default: 6 - */ - public int getMaxExpressionDepth() { return maxExpressionDepth; } - public void setMaxExpressionDepth(int maxExpressionDepth) { this.maxExpressionDepth = maxExpressionDepth; } - - /** - * Whether mutation should allow creation of condition (if) expressions. - * Default: true - */ - public boolean getAllowConditions() { return allowConditions; } - public void setAllowConditions(boolean allowConditions) { this.allowConditions = allowConditions; } - - /** - * Whether errors are relative to the absolute value of the function at that point or not. - * If true, training will assign equal weight to the error of 1.1 for 1 and 110 for 100. - * If false, training will instead assign a 10x weight to the latter. - * Default: True. - */ - public boolean getErrorIsRelative() { return errorIsRelative; } - public void setErrorIsRelative(boolean errorIsRelative) { this.errorIsRelative = errorIsRelative; } - - /** - * Returns the set of features to exclude during training. - * Returned as an immutable set, never null. - */ - public Set<String> getExcludeFeatures() { return excludeFeatures; } - /** Sets the features to exclude from a comma-separated string */ - public void setExcludeFeatures(String excludeFeatureString) { - for (String featureName : excludeFeatureString.split(",")) - excludeFeatures.add(featureName.trim()); - } - - /** - * Returns the format of the training set to read. "fv" or "cvs" is supported. - * If this is null the format name is taken from the last name of the file instead. - * Default: null. - */ - public String getTrainingSetFormat() { return trainingSetFormat; } - public void setTrainingSetFormat(String trainingSetFormat) { this.trainingSetFormat = trainingSetFormat; } - - /** - * Returns the fraction of the result set to hold out of training and use for validation. - * Default 0.2 - */ - public double getValidationFraction() { return validationFraction; } - public void setValidationFraction(double validationFraction) { this.validationFraction = validationFraction; } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingSet.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingSet.java deleted file mode 100644 index f7917987f91..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/TrainingSet.java +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; - -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -/** - * A training set: a set of <i>cases</i>: Input data to output value pairs - * - * @author bratseth - */ -public class TrainingSet { - - private final TrainingParameters parameters; - private final List<Case> trainingCases; - private final List<Case> validationCases; - private final Set<String> argumentNames = new HashSet<>(); - - /** - * Creates a training set from a list of cases. - * The ownership of the argument list and all the cases are transferred to this by this call. - */ - public TrainingSet(CaseList caseList, TrainingParameters parameters) { - List<Case> cases = caseList.cases(); - - this.parameters = parameters; - for (Case aCase : cases) - argumentNames.addAll(aCase.arguments().names()); - argumentNames.removeAll(parameters.getExcludeFeatures()); - - int validationCaseCount = (int)Math.round((cases.size() * parameters.getValidationFraction())); - this.validationCases = cases.subList(0, validationCaseCount); - this.trainingCases = cases.subList(validationCaseCount, cases.size()); - } - - public Set<String> argumentNames() { - return Collections.unmodifiableSet(argumentNames); - } - - /** - * Returns the fitness of a genome (ranking expression) according to this training set. - * The fitness to be returned by this is the inverse of the average squared difference between the - * target function result and the function result returned by the genome function. - */ - // TODO: Take expression length into account. - public double evaluate(RankingExpression genome) { - boolean constantExpressionGenome = true; - double squaredErrorSum = 0; - Double previousValue = null; - for (Case trainingCase : trainingCases) { - double value = genome.evaluate(trainingCase.arguments()).asDouble(); - double error = saneAbs(effectiveError(trainingCase.targetValue(), value)); - squaredErrorSum += Math.pow(error, 2); - - if (previousValue != null && previousValue != value) - constantExpressionGenome = false; - previousValue = value; - } - if (constantExpressionGenome) return 0; // Disqualify constant expressions as we know we're not looking for them - return 1 / (squaredErrorSum / trainingCases.size()); - } - - private double effectiveError(double a, double b) { - return parameters.getErrorIsRelative() ? errorFraction(a, b) : a - b; - } - - /** Calculate error in a way which is easy to understand (but which behaves badly when the target is around 0 */ - public double calculateAverageError(RankingExpression genome) { - double errorSum=0; - for (Case trainingCase : trainingCases) - errorSum += saneAbs(trainingCase.targetValue() - genome.evaluate(trainingCase.arguments()).asDouble()); - return errorSum/(double) trainingCases.size(); - } - - /** Calculate error in a way which is easy to understand (but which behaves badly when the target is around 0 */ - public double calculateAverageErrorPercentage(RankingExpression genome) { - double errorFractionSum = 0; - for (Case trainingCase : trainingCases) { - double errorFraction = saneAbs(errorFraction(trainingCase.targetValue(), genome.evaluate(trainingCase.arguments()).asDouble())); - // System.out.println("Error %: " + (100 * errorFraction + " Target: " + trainingCase.targetValue() + " Learned: " + genome.evaluate(trainingCase.arguments()).asDouble())); - errorFractionSum += errorFraction; - } - return ( errorFractionSum/(double) trainingCases.size() ) *100; - } - - private double errorFraction(double a, double b) { - double error = a - b; - if (error == 0 ) return 0; // otherwise a or b is different from 0 - if (a != 0) - return error / a; - else - return error / b; - } - - private double saneAbs(double d) { - if (Double.isInfinite(d) || Double.isNaN(d)) return Double.MAX_VALUE; - return Math.abs(d); - } - - public static class Case { - - private Context arguments; - - private double targetValue; - - public Case(Context arguments, double targetValue) { - this.arguments = arguments; - this.targetValue = targetValue; - } - - public double targetValue() { return targetValue; } - - public Context arguments() { return arguments; } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/CsvFileCaseList.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/CsvFileCaseList.java deleted file mode 100644 index c7f4f848b71..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/CsvFileCaseList.java +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.caselist; - -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; - -import java.util.Optional; - -/** - * <p>A list of training set cases created by reading a file containing lines specifying a case - * per line using the following syntax - * <code>targetValue, argument1:value, argument2:value2, ...</code> - * where arguments are identifiers and values are doubles.</p> - * - * <p>Comment lines starting with "#" are ignored.</p> - * - * @author bratseth - */ -public class CsvFileCaseList extends FileCaseList { - - public CsvFileCaseList(String fileName) { - super(fileName); - } - - protected Optional<TrainingSet.Case> lineToCase(String line, int lineNumber) { - String[] elements = line.split(","); - if (elements.length<2) - throw new IllegalArgumentException("At line " + lineNumber + ": Expected a comma-separated case on the " + - "form 'targetValue, argument1:value1, ...', but got '" + line ); - - double target; - try { - target = Double.parseDouble(elements[0].trim()); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("At line " + lineNumber + ": Expected a target value double " + - "at the start of the line, got '" + elements[0] + "'"); - } - - Context context = new MapContext(); - for (int i=1; i<elements.length; i++) { - String[] argumentPair = elements[i].split(":"); - try { - if (argumentPair.length != 2) throw new IllegalArgumentException(); - context.put(argumentPair[0].trim(),Double.parseDouble(argumentPair[1].trim())); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("At line " + lineNumber + ", element " + (i+1) + - ": Expected argument on the form 'identifier:double', got '" + elements[i] + "'"); - } - } - return Optional.of(new TrainingSet.Case(context, target)); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/FileCaseList.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/FileCaseList.java deleted file mode 100644 index 35a4d58d16c..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/FileCaseList.java +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.caselist; - -import com.yahoo.searchlib.mlr.ga.CaseList; -import com.yahoo.searchlib.mlr.ga.TrainingParameters; -import com.yahoo.searchlib.mlr.ga.TrainingSet; - -import java.io.BufferedReader; -import java.io.FileReader; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; - -/** - * @author bratseth - */ -public abstract class FileCaseList implements CaseList { - - private List<TrainingSet.Case> cases = new ArrayList<>(); - - /** - * Reads a case list from file. - * - * @throws IllegalArgumentException if the file could not be found or opened - */ - public FileCaseList(String fileName) { - try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) { - String line; - int lineNumber=0; - while (null != (line=reader.readLine())) { - lineNumber++; - line = line.trim(); - if (line.startsWith("#")) continue; - if (line.isEmpty()) continue; - Optional<TrainingSet.Case> newCase = lineToCase(line, lineNumber); - if (newCase.isPresent()) - cases.add(newCase.get()); - - } - } - catch (IOException | IllegalArgumentException e) { - throw new IllegalArgumentException("Could not create a case list from file '" + fileName + "'", e); - } - } - - /** Returns the case constructed from reading a line, if any */ - protected abstract Optional<TrainingSet.Case> lineToCase(String line, int lineNumber); - - @Override - public List<TrainingSet.Case> cases() { return Collections.unmodifiableList(cases); } - - /** Creates a file case list of the type specified in the parameters */ - public static FileCaseList create(String fileName, TrainingParameters parameters) { - String format = parameters.getTrainingSetFormat(); - if (format == null) - format = ending(fileName); - - switch (format) { - case "csv" : return new CsvFileCaseList(fileName); - case "fv" : return new FvFileCaseList(fileName); - default : throw new IllegalArgumentException("Unknown file format '" + format + "'"); - } - } - - private static String ending(String fileName) { - int lastDot = fileName.lastIndexOf("."); - if (lastDot <= 0) return null; - return fileName.substring(lastDot + 1, fileName.length()); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/FvFileCaseList.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/FvFileCaseList.java deleted file mode 100644 index 0e5b2aac729..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/ga/caselist/FvFileCaseList.java +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.caselist; - -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; - -import java.util.Optional; - -/** - * A list of training set cases created by reading a file containing lines specifying a case - * per line using the following syntax - * <code>feature1\tfeature2\tfeature3\t...\ttarget1</code> - * <p> - * The first line contains the name of each feature in the same order. - * - * <p>Comment lines starting with "#" are ignored.</p> - * - * @author bratseth - */ -// NOTE: If we get another type of case list it is time to abstract into a common CaseList base class -public class FvFileCaseList extends FileCaseList { - - private String[] argumentNames; - - public FvFileCaseList(String fileName) { - super(fileName); - } - - protected Optional<TrainingSet.Case> lineToCase(String line, int lineNumber) { - String[] values = line.split("\t"); - - if (argumentNames == null) { // first line - argumentNames = values; - return Optional.empty(); - } - - if (argumentNames.length != values.length) - throw new IllegalArgumentException("Wrong number of values at line " + lineNumber); - - - Context context = new MapContext(); - for (int i = 0; i < values.length-1; i++) - context.put(argumentNames[i], toDouble(values[i], lineNumber)); - - double target = toDouble(values[values.length-1], lineNumber); - return Optional.of(new TrainingSet.Case(context, target)); - } - - private double toDouble(String s, int lineNumber) { - try { - return Double.parseDouble(s.trim()); - } catch (NumberFormatException e) { - throw new IllegalArgumentException("At line " + lineNumber + ": Expected only double values, " + - "got '" + s + "'"); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/mlr/gbdt/ExpressionAnalysis.java b/searchlib/src/main/java/com/yahoo/searchlib/mlr/gbdt/ExpressionAnalysis.java deleted file mode 100644 index 16a4f6f931b..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/mlr/gbdt/ExpressionAnalysis.java +++ /dev/null @@ -1,425 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.gbdt; - -import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode; -import com.yahoo.yolean.Exceptions; -import com.yahoo.searchlib.mlr.ga.Individual; -import com.yahoo.searchlib.mlr.ga.PrintingTracker; -import com.yahoo.searchlib.mlr.ga.RankingExpressionCaseList; -import com.yahoo.searchlib.mlr.ga.Trainer; -import com.yahoo.searchlib.mlr.ga.TrainingParameters; -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.IfNode; -import com.yahoo.searchlib.rankingexpression.rule.NegativeNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; - -import java.io.BufferedReader; -import java.io.FileNotFoundException; -import java.io.FileReader; -import java.io.IOException; -import java.io.Reader; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Random; - -/** - * A standalone tool which analyzes a GBDT form ranking expression - * - * @author bratseth - */ -public class ExpressionAnalysis { - - private final Map<String, Feature> features = new HashMap<>(); - - private int currentTree; - - private final RankingExpression expression; - - public ExpressionAnalysis(RankingExpression expression) { - this.expression = expression; - if ( ! instanceOf(expression.getRoot(), ArithmeticNode.class)) return; - analyzeSum((ArithmeticNode)expression.getRoot()); - } - - /** Returns the expression analyzed by this */ - public RankingExpression expression() { return expression; } - - /** Returns the analysis of each feature in this expression as a read-only map indexed by feature name */ - private Map<String, Feature> featureMap() { - return Collections.unmodifiableMap(features); - } - - /** Returns list containing the analysis of each feature, sorted by decreasing usage */ - private List<Feature> features() { - List<Feature> featureList = new ArrayList<>(features.values()); - Collections.sort(featureList); - return featureList; - } - - /** Returns the name of each feature, sorted by decreasing usage */ - private List<String> featureNames() { - List<String> featureNameList = new ArrayList<>(features.values().size()); - for (Feature feature : features()) - featureNameList.add(feature.name()); - return featureNameList; - } - - private void analyzeSum(ArithmeticNode node) { - for (ExpressionNode child : node.children()) { - currentTree++; - analyze(child); - } - } - - private void analyze(ExpressionNode node) { - if (node instanceof IfNode) { - analyzeIf((IfNode)node); - } - - if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - analyze(child); - } - } - - private void analyzeIf(IfNode node) { - if (node.getCondition() instanceof ComparisonNode) - analyzeComparisonIf(node); - else if (node.getCondition() instanceof SetMembershipNode) - analyzeSetMembershipIf(node); - else - System.err.println("Warning: Expected a comparison or set membership test, got " + node.getCondition().getClass()); - } - - private void analyzeComparisonIf(IfNode node) { - ComparisonNode comparison = (ComparisonNode)node.getCondition(); - - if (comparison.getOperator() != TruthOperator.SMALLER) { - System.err.println("Warning: This expression has " + comparison.getOperator() + " where we expect < :" + - comparison); - return; - } - - if ( ! instanceOf(comparison.getLeftCondition(), ReferenceNode.class)) return; - String featureName = ((ReferenceNode)comparison.getLeftCondition()).getName(); - - Double value = nodeValue(comparison.getRightCondition()); - if (value == null) return; - - ComparisonFeature feature = (ComparisonFeature)features.get(featureName); - if (feature == null) { - feature = new ComparisonFeature(featureName); - features.put(featureName, feature); - } - feature.isComparedTo(value, currentTree, average(node.getTrueExpression()), average(node.getFalseExpression())); - } - - private void analyzeSetMembershipIf(IfNode node) { - SetMembershipNode membershipTest = (SetMembershipNode)node.getCondition(); - - if ( ! instanceOf(membershipTest.getTestValue(), ReferenceNode.class)) return; - String featureName = ((ReferenceNode)membershipTest.getTestValue()).getName(); - - SetMembershipFeature feature = (SetMembershipFeature)features.get(featureName); - if (feature == null) { - feature = new SetMembershipFeature(featureName); - features.put(featureName, feature); - } - } - - /** - * Returns the value of a constant node, or a negative wrapping a constant. - * Warns and returns null if it is neither. - */ - private Double nodeValue(ExpressionNode node) { - if (node instanceof NegativeNode) { - NegativeNode negativeNode = (NegativeNode)node; - if ( ! instanceOf(negativeNode.getValue(), ConstantNode.class)) return null; - return - ((ConstantNode)negativeNode.getValue()).getValue().asDouble(); - } - else { - if ( ! instanceOf(node, ConstantNode.class)) return null; - return ((ConstantNode)node).getValue().asDouble(); - } - } - - - /** Returns the average value of all the leaf constants below this */ - private double average(ExpressionNode node) { - Sum sum = new Sum(); - average(node, sum); - return sum.average(); - } - - private void average(ExpressionNode node, Sum sum) { - if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - average(child, sum); - } - else { - Double value = nodeValue(node); - if (value == null) return; - sum.add(value); - } - } - - private boolean instanceOf(Object object, Class<?> clazz) { - if (clazz.isAssignableFrom(object.getClass())) return true; - System.err.println("Warning: This expression has " + object.getClass() + " where we expect " + clazz + - ": Instance " + object); - return false; - } - - private List<Context> generateArgumentSets(int count) { - List<Context> argumentSets = new ArrayList<>(count); - for (int i=0; i<count; i++) { - ArgumentIgnoringMapContext context = new ArgumentIgnoringMapContext(); - for (Feature feature : features()) { - if (feature instanceof ComparisonFeature) { - ComparisonFeature comparison = (ComparisonFeature)feature; - context.put(comparison.name(),randomBetween(comparison.lowerBound(), comparison.upperBound())); - } - // TODO: else if (feature instanceof SetMembershipFeature) - } - argumentSets.add(context); - } - return argumentSets; - } - - private Random random = new Random(); - /** Returns a random value in [lowerBound, upperBound> */ - private double randomBetween(double lowerBound, double upperBound) { - return random.nextDouble()*(upperBound-lowerBound)+lowerBound; - } - - private static class ArgumentIgnoringMapContext extends MapContext { - - @Override - public Value get(String name, Arguments arguments,String output) { - return super.get(name, null, output); - } - - } - - /** Generates a textual report from analyzing this expression */ - public String report() { - StringBuilder b = new StringBuilder(); - b.append("Trees: " + currentTree).append("\n"); - b.append("Features:\n"); - for (Feature feature : features()) - b.append(" " + feature).append("\n"); - return b.toString(); - } - - private static final String usage = "\nUsage: ExpressionAnalysis [myExpressionFile.expression]"; - - public static void main(String[] args) { - if (args.length == 0) error("No arguments." + usage); - - ExpressionAnalysis analysis = analysisFromFile(args[0]); - - if (1==1) return; // Turn off ga training - if (args.length == 1) { - new GATraining(analysis); - } - else if (args.length == 2) { - try { - new LearntExpressionAnalysis(analysis, new RankingExpression(args[1])); - } - catch (ParseException e) { - error("Syntax error in argument expression: " + Exceptions.toMessageString(e)); - } - } - else { - error("Unexpectedly got more than 2 arguments." + usage); - } - - } - - private static ExpressionAnalysis analysisFromFile(String fileName) { - try (Reader fileReader = new BufferedReader(new FileReader(fileName))) { - System.out.println("Analyzing " + fileName + "..."); - ExpressionAnalysis analysis = new ExpressionAnalysis(new RankingExpression(fileReader)); - System.out.println(analysis.report()); - return analysis; - } - catch (FileNotFoundException e) { - error("Could not find '" + fileName + "'"); - } - catch (IOException e) { - error("Failed reading '" + fileName + "': " + Exceptions.toMessageString(e)); - } - catch (ParseException e) { - error("Syntax error in '" + fileName + "': " + Exceptions.toMessageString(e)); - } - return null; - } - - private static class LearntExpressionAnalysis { - - public LearntExpressionAnalysis(ExpressionAnalysis analysis, RankingExpression learntExpression) { - int cases = 1000; - TrainingSet newTrainingSet = new TrainingSet(new RankingExpressionCaseList(analysis.generateArgumentSets(cases), - analysis.expression()), new TrainingParameters()); - Individual winner = new Individual(learntExpression, newTrainingSet); - System.out.println("With separate training set: " + winner.toShortString() + " (" + winner.calculateAverageError() + ")"); - } - - } - - private static class GATraining { - - public GATraining(ExpressionAnalysis analysis) { - int skipFeatures = 0; - int featureCount = analysis.featureNames().size(); - int cases = 1000; - TrainingParameters parameters = new TrainingParameters(); - parameters.setInitialSpeciesSize(50); - parameters.setSpeciesLifespan(50); - //parameters.setAllowConditions(false); // disallow non-smooth functions - parameters.setMaxExpressionDepth(8); - TrainingSet trainingSet = new TrainingSet(new RankingExpressionCaseList(analysis.generateArgumentSets(cases), - analysis.expression()), parameters); - Trainer trainer = new Trainer(trainingSet, new HashSet<>(analysis.featureNames().subList(skipFeatures, featureCount))); - - System.out.println("Learning ..."); - RankingExpression learntExpression = trainer.train(parameters, new PrintingTracker(100, 0, 1)); - System.out.println("Learnt expression: " + learntExpression); - - // Check for overtraining - new LearntExpressionAnalysis(analysis, learntExpression); - } - - } - - private static void error(String message) { - System.err.println(message); - System.exit(1); - } - - public abstract static class Feature implements Comparable<Feature> { - - private final String name; - - protected Feature(String name) { - this.name = name; - } - - public String name() { return name; } - - /** Primary sort by type, secondary by name */ - @Override - public int compareTo(Feature other) { - int typeComparison = this.getClass().getName().compareTo(other.getClass().getName()); - if (typeComparison != 0) return typeComparison; - return this.name.compareTo(other.name); - } - - } - - /** A feature used in comparisons. These are the ones on which our serious analysis is focused */ - public static class ComparisonFeature extends Feature { - - private double lowerBound = Double.MAX_VALUE; - private double upperBound = Double.MIN_VALUE; - - /** The number of usages of this feature */ - private int usages = 0; - - /** The sum of the tree numbers where this is accessed */ - private int treeNumberSum = 0; - - /** - * The net times where the left values are smaller than the right values for this - * (which is a measure of correlation between input and output because the comparison is <) - */ - private int correlationCount = 0; - - /** - * The sum difference in returned value between choosing the right and left branch due to this feature - */ - private double netSum = 0; - - public ComparisonFeature(String name) { - super(name); - } - - public double lowerBound() { return lowerBound; } - public double upperBound() { return upperBound; } - - public void isComparedTo(double value, int inTreeNumber, double leftAverage, double rightAverage) { - lowerBound = Math.min(lowerBound, value); - upperBound = Math.max(upperBound, value); - usages++; - treeNumberSum += inTreeNumber; - correlationCount += leftAverage < rightAverage ? 1 : -1; - netSum += rightAverage - leftAverage; - } - - /** Override to do secondary sort by usages */ - public int compareTo(Feature o) { - if ( ! (o instanceof ComparisonFeature)) return super.compareTo(o); - ComparisonFeature other = (ComparisonFeature)o; - return - Integer.compare(this.usages, other.usages); - } - - @Override - public String toString() { - return "Numeric feature: " + name() + - ": range [" + lowerBound + ", " + upperBound + "]" + - ", usages " + usages + - ", average tree occurrence " + (treeNumberSum / usages) + - ", correlation: " + (correlationCount / (double)usages) + - ", net contribution: " + netSum; - } - - } - - /** A feature used in set membership tests */ - public static class SetMembershipFeature extends Feature { - - public SetMembershipFeature(String name) { - super(name); - } - - @Override - public String toString() { - return "Categorical feature: " + name(); - } - - } - - /** A sum which can returns its average */ - private static class Sum { - - private double sum; - private int count; - - public void add(double value) { - sum+=value; - count++; - } - - public double average() { - return sum / count; - } - - } - -} diff --git a/searchlib/src/main/sh/ga b/searchlib/src/main/sh/ga deleted file mode 100644 index 009b9684160..00000000000 --- a/searchlib/src/main/sh/ga +++ /dev/null @@ -1,67 +0,0 @@ -#! /bin/sh -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -# BEGIN environment bootstrap section -# Do not edit between here and END as this section should stay identical in all scripts - -findpath () { - myname=${0} - mypath=${myname%/*} - myname=${myname##*/} - if [ "$mypath" ] && [ -d "$mypath" ]; then - return - fi - mypath=$(pwd) - if [ -f "${mypath}/${myname}" ]; then - return - fi - echo "FATAL: Could not figure out the path where $myname lives from $0" - exit 1 -} - -COMMON_ENV=libexec/vespa/common-env.sh - -source_common_env () { - if [ "$VESPA_HOME" ] && [ -d "$VESPA_HOME" ]; then - export VESPA_HOME - common_env=$VESPA_HOME/$COMMON_ENV - if [ -f "$common_env" ]; then - . $common_env - return - fi - fi - return 1 -} - -findroot () { - source_common_env && return - if [ "$VESPA_HOME" ]; then - echo "FATAL: bad VESPA_HOME value '$VESPA_HOME'" - exit 1 - fi - if [ "$ROOT" ] && [ -d "$ROOT" ]; then - VESPA_HOME="$ROOT" - source_common_env && return - fi - findpath - while [ "$mypath" ]; do - VESPA_HOME=${mypath} - source_common_env && return - mypath=${mypath%/*} - done - echo "FATAL: missing VESPA_HOME environment variable" - echo "Could not locate $COMMON_ENV anywhere" - exit 1 -} - -findroot - -# END environment bootstrap section - -JAR=$VESPA_HOME/lib/jars/searchlib-deploy.jar -if [[ "$1" == *.jar ]]; then - JAR=$1 -fi -shift - -exec java -cp $JAR com.yahoo.searchlib.mlr.ga.Main "$@" diff --git a/searchlib/src/main/sh/gbdt-analysis b/searchlib/src/main/sh/gbdt-analysis deleted file mode 100755 index eae755689b0..00000000000 --- a/searchlib/src/main/sh/gbdt-analysis +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -java -cp target/searchlib.jar com.yahoo.searchlib.mlr.gbdt.ExpressionAnalysis $@ diff --git a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/CsvFileCaseListTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/CsvFileCaseListTestCase.java deleted file mode 100644 index 68f705315ad..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/CsvFileCaseListTestCase.java +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.test; - -import com.yahoo.searchlib.mlr.ga.TrainingParameters; -import com.yahoo.searchlib.mlr.ga.caselist.CsvFileCaseList; -import com.yahoo.yolean.Exceptions; -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import org.junit.Test; -import static org.junit.Assert.*; - -/** - * @author bratseth - */ -public class CsvFileCaseListTestCase { - - private static final double delta = 0.000001; - - @Test - public void testLegalFile() { - CsvFileCaseList list = new CsvFileCaseList("src/test/files/mlr/cases.csv"); - - assertEquals(3,list.cases().size()); - { - TrainingSet.Case case1 = list.cases().get(0); - assertEquals(1.0, case1.targetValue(), delta); - assertEquals(2, case1.arguments().names().size()); - assertEquals(2.0, case1.arguments().get("arg1").asDouble(),delta); - assertEquals(-1.3, case1.arguments().get("arg2").asDouble(),delta); - } - - { - TrainingSet.Case case2 = list.cases().get(1); - assertEquals(-1.003, case2.targetValue(), delta); - assertEquals(1, case2.arguments().names().size()); - assertEquals(500007, case2.arguments().get("arg1").asDouble(),delta); - } - - { - TrainingSet.Case case3 = list.cases().get(2); - assertEquals(0, case3.targetValue(), delta); - assertEquals(1, case3.arguments().names().size()); - assertEquals(1.0, case3.arguments().get("arg2").asDouble(),delta); - } - - TrainingSet trainingSet = new TrainingSet(list, new TrainingParameters()); - assertEquals(2, trainingSet.argumentNames().size()); - assertTrue(trainingSet.argumentNames().contains("arg1")); - assertTrue(trainingSet.argumentNames().contains("arg2")); - } - - @Test - public void testNonExistingFile() { - try { - new CsvFileCaseList("nosuchfile"); - } - catch (IllegalArgumentException e) { - assertEquals("Could not create a case list from file 'nosuchfile': nosuchfile (No such file or directory)", Exceptions.toMessageString(e)); - } - } - - @Test - public void testInvalidFile1() { - try { - new CsvFileCaseList("src/test/files/mlr/cases-illegal1.csv"); - } - catch (IllegalArgumentException e) { - assertEquals("Could not create a case list from file 'src/test/files/mlr/cases-illegal1.csv': At line 5, element 3: Expected argument on the form 'identifier:double', got ' arg2:'", Exceptions.toMessageString(e)); - } - } - - @Test - public void testInvalidFile2() { - try { - new CsvFileCaseList("src/test/files/mlr/cases-illegal2.csv"); - } - catch (IllegalArgumentException e) { - assertEquals("Could not create a case list from file 'src/test/files/mlr/cases-illegal2.csv': At line 2: Expected a target value double at the start of the line, got '5db'", Exceptions.toMessageString(e)); - } - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/ExampleLearningSessions.java b/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/ExampleLearningSessions.java deleted file mode 100644 index 4de83d16300..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/ExampleLearningSessions.java +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.test; - -import com.yahoo.searchlib.mlr.ga.PrintingTracker; -import com.yahoo.searchlib.mlr.ga.RankingExpressionCaseList; -import com.yahoo.searchlib.mlr.ga.Trainer; -import com.yahoo.searchlib.mlr.ga.TrainingParameters; -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; - -import java.util.ArrayList; -import java.util.List; - -/** - * Main class - drives a learning session from the command line. - * - * @author bratseth - */ -public class ExampleLearningSessions { - - public static void main(String[] args) throws ParseException { - test3(); - } - - // Always learnt precisely in less than a second - private static void test1() throws ParseException { - TrainingParameters parameters = new TrainingParameters(); - - RankingExpression target = new RankingExpression("2*x"); - List<Context> arguments = new ArrayList<>(); - arguments.add(MapContext.fromString("x:0").freeze()); - arguments.add(MapContext.fromString("x:1").freeze()); - arguments.add(MapContext.fromString("x:2").freeze()); - TrainingSet trainingSet = new TrainingSet(new RankingExpressionCaseList(arguments, target), parameters); - - Trainer trainer = new Trainer(trainingSet); - - System.out.println("Learning ..."); - RankingExpression learntExpression = trainer.train(parameters, new PrintingTracker()); - } - - // Solved well in a few seconds at most. Slow going thereafter. - private static void test2() throws ParseException { - TrainingParameters parameters = new TrainingParameters(); - parameters.setSpeciesLifespan(100); // Shorter lifespan is faster? - - RankingExpression target = new RankingExpression("5*x*x + 2*x + 13"); - List<Context> arguments = new ArrayList<>(); - arguments.add(MapContext.fromString("x:0").freeze()); - arguments.add(MapContext.fromString("x:1").freeze()); - arguments.add(MapContext.fromString("x:2").freeze()); - arguments.add(MapContext.fromString("x:3").freeze()); - arguments.add(MapContext.fromString("x:4").freeze()); - arguments.add(MapContext.fromString("x:5").freeze()); - arguments.add(MapContext.fromString("x:6").freeze()); - arguments.add(MapContext.fromString("x:7").freeze()); - arguments.add(MapContext.fromString("x:8").freeze()); - arguments.add(MapContext.fromString("x:9").freeze()); - arguments.add(MapContext.fromString("x:10").freeze()); - arguments.add(MapContext.fromString("x:50").freeze()); - arguments.add(MapContext.fromString("x:500").freeze()); - arguments.add(MapContext.fromString("x:5000").freeze()); - arguments.add(MapContext.fromString("x:50000").freeze()); - TrainingSet trainingSet = new TrainingSet(new RankingExpressionCaseList(arguments, target), parameters); - - Trainer trainer = new Trainer(trainingSet); - - System.out.println("Learning ..."); - RankingExpression learntExpression = trainer.train(parameters, new PrintingTracker()); - } - - // Solved well in at most a few minutes - private static void test3() throws ParseException { - TrainingParameters parameters = new TrainingParameters(); - parameters.setAllowConditions(false); // disallow non-smooth functions: Speeds up learning of smooth ones greatly - - RankingExpression target = new RankingExpression("-2.7*x*x*x + 5*x*x + 2*x + 13"); - List<Context> arguments = new ArrayList<>(); - arguments.add(MapContext.fromString("x:-50000").freeze()); - arguments.add(MapContext.fromString("x:-5000").freeze()); - arguments.add(MapContext.fromString("x:-500").freeze()); - arguments.add(MapContext.fromString("x:-50").freeze()); - arguments.add(MapContext.fromString("x:-10").freeze()); - arguments.add(MapContext.fromString("x:0").freeze()); - arguments.add(MapContext.fromString("x:1").freeze()); - arguments.add(MapContext.fromString("x:2").freeze()); - arguments.add(MapContext.fromString("x:3").freeze()); - arguments.add(MapContext.fromString("x:4").freeze()); - arguments.add(MapContext.fromString("x:5").freeze()); - arguments.add(MapContext.fromString("x:6").freeze()); - arguments.add(MapContext.fromString("x:7").freeze()); - arguments.add(MapContext.fromString("x:8").freeze()); - arguments.add(MapContext.fromString("x:9").freeze()); - arguments.add(MapContext.fromString("x:10").freeze()); - arguments.add(MapContext.fromString("x:50").freeze()); - arguments.add(MapContext.fromString("x:500").freeze()); - arguments.add(MapContext.fromString("x:5000").freeze()); - arguments.add(MapContext.fromString("x:50000").freeze()); - TrainingSet trainingSet = new TrainingSet(new RankingExpressionCaseList(arguments, target), parameters); - - Trainer trainer = new Trainer(trainingSet); - - System.out.println("Learning ..."); - RankingExpression learntExpression = trainer.train(parameters, new PrintingTracker()); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/MainTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/MainTestCase.java deleted file mode 100644 index f5febe2ab68..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/MainTestCase.java +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.test; - -import com.yahoo.searchlib.mlr.ga.Evolvable; -import com.yahoo.searchlib.mlr.ga.Main; -import com.yahoo.searchlib.mlr.ga.PrintingTracker; -import com.yahoo.searchlib.mlr.ga.Species; -import com.yahoo.searchlib.mlr.ga.Tracker; -import com.yahoo.searchlib.mlr.ga.TrainingParameters; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import org.junit.Test; -import static org.junit.Assert.*; - -import java.util.List; - -/** - * Tests the main class used from the command line - * - * @author bratseth - */ -public class MainTestCase { - - /** Tests that an extremely simple function expressed as cases in a file is learnt perfectly. */ - @Test - public void testMain() { - SilentTestTracker tracker = new SilentTestTracker(); - new Main(new String[] { "src/test/files/mlr/cases-linear.csv"}, tracker); - assertTrue(Double.isInfinite(tracker.winner.getFitness())); - } - - private static class SilentTestTracker implements Tracker { - - public Evolvable winner; - - @Override - public void newSpecies(Species predecessor, int initialSize, List<RankingExpression> genePool) { - } - - @Override - public void newSpeciesCreated(Species predecessor) { - } - - @Override - public void speciesCompleted(Species predecessor) { - } - - @Override - public void iteration(Species species, int generation) { - } - - @Override - public void result(Evolvable winner) { - this.winner = winner; - } - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/MockTrainingSetTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/MockTrainingSetTestCase.java deleted file mode 100644 index 2fc6e6cab3d..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/MockTrainingSetTestCase.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.test; - -import com.yahoo.searchlib.mlr.ga.RankingExpressionCaseList; -import com.yahoo.searchlib.mlr.ga.TrainingParameters; -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import org.junit.Test; -import static org.junit.Assert.*; - -import java.util.ArrayList; -import java.util.List; - -/** - * @author bratseth - */ -public class MockTrainingSetTestCase { - - @Test - public void testMockTrainingSet() throws ParseException { - RankingExpression target = new RankingExpression("2*x"); - List<Context> arguments = new ArrayList<>(); - arguments.add(MapContext.fromString("x:0")); - arguments.add(MapContext.fromString("x:1")); - arguments.add(MapContext.fromString("x:2")); - TrainingSet trainingSet = new TrainingSet(new RankingExpressionCaseList(arguments, target), new TrainingParameters()); - assertTrue(Double.isInfinite(trainingSet.evaluate(new RankingExpression("2*x")))); - assertEquals(4.0, trainingSet.evaluate(new RankingExpression("x")), 0.001); - assertEquals(0.0, trainingSet.evaluate(new RankingExpression("x/x")), 0.001); - } - - @Test - public void testEvaluation() throws ParseException { - // with freezing - assertEquals(16.0,new RankingExpression("2*x*x*x").evaluate(MapContext.fromString("x:2").freeze()).asDouble(),0.0001); - assertEquals(8.0,new RankingExpression("x*x+x*x").evaluate(MapContext.fromString("x:2").freeze()).asDouble(),0.0001); - - // without freezing - assertEquals(16.0,new RankingExpression("2*x*x*x").evaluate(MapContext.fromString("x:2")).asDouble(),0.0001); - assertEquals(8.0,new RankingExpression("x*x+x*x").evaluate(MapContext.fromString("x:2")).asDouble(),0.0001); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/TripAdvisorFileCaseList.java b/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/TripAdvisorFileCaseList.java deleted file mode 100644 index 7945e2605b0..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/mlr/ga/test/TripAdvisorFileCaseList.java +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.ga.test; - -import com.yahoo.searchlib.mlr.ga.CaseList; -import com.yahoo.searchlib.mlr.ga.TrainingSet; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; - -import java.io.BufferedReader; -import java.io.FileReader; -import java.io.IOException; -import java.util.*; - -/** - * Reads a tripadvisor Kaggle challenge training set - * - * @author bratseth - */ -public class TripAdvisorFileCaseList implements CaseList { - - private List<TrainingSet.Case> cases = new ArrayList<>(); - private Map<Integer,String> columnNames = new HashMap<>(); - - /** - * Reads a case list from file. - * - * @throws IllegalArgumentException if the file could not be found or opened - */ - public TripAdvisorFileCaseList(String fileName) throws IllegalArgumentException { - System.out.print("Reading training data "); - try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) { - String line; - readColumnNames(reader.readLine()); - int lineNumber=1; - while (null != (line=reader.readLine())) { - lineNumber++; - line = line.trim(); - if (line.startsWith("#")) continue; - if (line.isEmpty()) continue; - cases.add(lineToCase(line, lineNumber)); - } - } - catch (IOException | IllegalArgumentException e) { - throw new IllegalArgumentException("Could not create a case list from file '" + fileName + "'", e); - } - System.out.println("done"); - } - - private void readColumnNames(String line) { - int columnNumber = 0; - for (String columnName : line.split(",")) - columnNames.put(columnNumber++, columnName); - } - - protected TrainingSet.Case lineToCase(String line, int lineNumber) { - if ((lineNumber % 10000) ==0) - System.out.print("."); - - Map<String,Double> columnValues = readColumns(line); - - double targetValue = columnValues.get("click_bool") + columnValues.get("booking_bool")*5; - - Context context = new MapContext(); - for (Map.Entry<String,Double> value : columnValues.entrySet()) { - if (value.getKey().equals("click_bool")) continue; - if (value.getKey().equals("gross_bookings_usd")) continue; - if (value.getKey().equals("booking_bool")) continue; - context.put(value.getKey(),value.getValue()); - } - return new TrainingSet.Case(context, targetValue); - } - - private Map<String, Double> readColumns(String line) { - Map<String,Double> columnValues = new LinkedHashMap<>(); - int columnNumber = 0; - for (String valueString : line.split(",")) { - String columnName = columnNames.get(columnNumber++); - if (columnName.equals("date_time")) continue; - Double columnValue; - if (valueString.equals("NULL")) { - columnValue = 0.0; - } - else { - try { - columnValue = Double.parseDouble(valueString); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Could not parse column '" + columnName + "'",e); - } - } - columnValues.put(columnName, columnValue); - } - return columnValues; - } - - @Override - public List<TrainingSet.Case> cases() { return Collections.unmodifiableList(cases); } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/mlr/gbdt/ExpressionAnalysisRunner.java b/searchlib/src/test/java/com/yahoo/searchlib/mlr/gbdt/ExpressionAnalysisRunner.java deleted file mode 100644 index 28f90ebb0fc..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/mlr/gbdt/ExpressionAnalysisRunner.java +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.mlr.gbdt; - -import org.junit.Ignore; -import org.junit.Test; - -/** - * Run an expression analyser without having to muck with classpath. - * - * @author bratseth - */ -public class ExpressionAnalysisRunner { - - @Test @Ignore - public void runAnalysis() { - ExpressionAnalysis.main(new String[] { "/Users/bratseth/Downloads/getty_mlr_001.expression"}); - } - -} |