summaryrefslogtreecommitdiffstats
path: root/libmlr
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-06-15 23:09:44 +0200
committerJon Bratseth <bratseth@yahoo-inc.com>2016-06-15 23:09:44 +0200
commit72231250ed81e10d66bfe70701e64fa5fe50f712 (patch)
tree2728bba1131a6f6e5bdf95afec7d7ff9358dac50 /libmlr
Publish
Diffstat (limited to 'libmlr')
-rw-r--r--libmlr/OWNERS2
-rwxr-xr-xlibmlr/bin/xml2cpp72
-rw-r--r--libmlr/pom.xml53
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java539
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java18
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java65
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java28
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java16
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java211
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java50
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java8
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java61
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java31
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java36
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java21
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java57
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java39
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java8
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java103
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java17
-rw-r--r--libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java435
-rw-r--r--libmlr/src/main/java/config/header_template.txt17
22 files changed, 1887 insertions, 0 deletions
diff --git a/libmlr/OWNERS b/libmlr/OWNERS
new file mode 100644
index 00000000000..76e34e72c9d
--- /dev/null
+++ b/libmlr/OWNERS
@@ -0,0 +1,2 @@
+lesters
+tmartins
diff --git a/libmlr/bin/xml2cpp b/libmlr/bin/xml2cpp
new file mode 100755
index 00000000000..6367d745f78
--- /dev/null
+++ b/libmlr/bin/xml2cpp
@@ -0,0 +1,72 @@
+#!/bin/bash
+# Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+# BEGIN environment bootstrap section
+# Do not edit between here and END as this section should stay identical in all scripts
+
+findpath () {
+ myname=${0}
+ mypath=${myname%/*}
+ myname=${myname##*/}
+ if [ "$mypath" ] && [ -d "$mypath" ]; then
+ return
+ fi
+ mypath=$(pwd)
+ if [ -f "${mypath}/${myname}" ]; then
+ return
+ fi
+ echo "FATAL: Could not figure out the path where $myname lives from $0"
+ exit 1
+}
+
+COMMON_ENV=libexec/vespa/common-env.sh
+
+source_common_env () {
+ if [ "$VESPA_HOME" ] && [ -d "$VESPA_HOME" ]; then
+ # ensure it ends with "/" :
+ VESPA_HOME=${VESPA_HOME%/}/
+ export VESPA_HOME
+ common_env=$VESPA_HOME/$COMMON_ENV
+ if [ -f "$common_env" ]; then
+ . $common_env
+ return
+ fi
+ fi
+ return 1
+}
+
+findroot () {
+ source_common_env && return
+ if [ "$VESPA_HOME" ]; then
+ echo "FATAL: bad VESPA_HOME value '$VESPA_HOME'"
+ exit 1
+ fi
+ if [ "$ROOT" ] && [ -d "$ROOT" ]; then
+ VESPA_HOME="$ROOT"
+ source_common_env && return
+ fi
+ findpath
+ while [ "$mypath" ]; do
+ VESPA_HOME=${mypath}
+ source_common_env && return
+ mypath=${mypath%/*}
+ done
+ echo "FATAL: missing VESPA_HOME environment variable"
+ echo "Could not locate $COMMON_ENV anywhere"
+ exit 1
+}
+
+findroot
+
+# END environment bootstrap section
+
+if [ $# -lt 2 ]; then
+ echo "USAGE $0 <template> <mlrxxx.xml>";
+ exit;
+fi
+
+JAVA=java
+CLASSPATH=$VESPA_HOME/lib/jars/xml2cpp.jar
+
+echo ${JAVA} -Xms64m -Xmx256m -cp ${CLASSPATH} com.yahoo.yst.libmlr.converter.DecisionTreeXmlToCpp -h $1 -i $2
+${JAVA} -Xms64m -Xmx256m -cp ${CLASSPATH} com.yahoo.yst.libmlr.converter.DecisionTreeXmlToCpp -h $1 -i $2
diff --git a/libmlr/pom.xml b/libmlr/pom.xml
new file mode 100644
index 00000000000..53ac04ffcc7
--- /dev/null
+++ b/libmlr/pom.xml
@@ -0,0 +1,53 @@
+<?xml version="1.0"?>
+<!-- Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>xml2cpp</artifactId>
+ <packaging>jar</packaging>
+ <version>1.0.0-SNAPSHOT</version>
+ <name>xml2cpp</name>
+ <description>Fork of xml2cppConverver with support for SS3 models.</description>
+ <dependencies>
+ </dependencies>
+ <build>
+ <finalName>${project.artifactId}</finalName>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-deploy-plugin</artifactId>
+ <version>2.5</version>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>3.1</version>
+ <configuration>
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-serial</arg>
+ <arg>-Werror</arg>
+ </compilerArgs>
+ <showWarnings>true</showWarnings>
+ <source>1.8</source>
+ <target>1.8</target>
+ <showDeprecation>true</showDeprecation>
+ <showWarnings>true</showWarnings>
+ <optimize>true</optimize>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-install-plugin</artifactId>
+ <version>2.3.1</version>
+ <configuration>
+ <updateReleaseInfo>true</updateReleaseInfo>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ <properties>
+ <test.hide>true</test.hide>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ </properties>
+</project>
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java
new file mode 100644
index 00000000000..1dc6daeec41
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java
@@ -0,0 +1,539 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.text.MessageFormat;
+import java.util.Date;
+import java.util.HashSet;
+
+import com.yahoo.yst.libmlr.converter.entity.EarlyExit;
+import com.yahoo.yst.libmlr.converter.entity.Epilog;
+import com.yahoo.yst.libmlr.converter.entity.FuncNormalize;
+import com.yahoo.yst.libmlr.converter.entity.FuncPolytransform;
+import com.yahoo.yst.libmlr.converter.entity.Function;
+import com.yahoo.yst.libmlr.converter.entity.InternalNode;
+import com.yahoo.yst.libmlr.converter.entity.ResponseNode;
+import com.yahoo.yst.libmlr.converter.entity.Tree;
+import com.yahoo.yst.libmlr.converter.entity.TreeNode;
+import com.yahoo.yst.libmlr.converter.entity.TreeNodeVisitor;
+import com.yahoo.yst.libmlr.converter.entity.TreenetFunction;
+import com.yahoo.yst.libmlr.converter.parser.DecisionTreeXmlException;
+import com.yahoo.yst.libmlr.converter.parser.MlrXmlParser;
+
+/**
+ * This class generates C++ from an MLR Decision Tree File
+ *
+ * @author allenwei
+ *
+ */
+public class DecisionTreeXmlToCpp {
+
+ private static final String INDENT_UNIT = " ";
+
+ private TreenetFunction tnFunc;
+ private String strCppFile;
+ private PrintWriter fpCpp;
+
+ private int nodeIdx;
+ private int gIndentLevel; // global indent level
+
+ public void setTnFunc(TreenetFunction tnFunc) {
+ this.tnFunc = tnFunc;
+ }
+
+ public DecisionTreeXmlToCpp(String file) {
+ strCppFile = file;
+ try {
+ fpCpp = new PrintWriter(
+ new BufferedWriter(
+ new FileWriter(strCppFile)));
+ } catch (IOException ioex) {
+ System.out.println("Cannot open " + strCppFile + " for write");
+ }
+ }
+
+ /**
+ * Generates C++ code.
+ */
+ public void genCode(String strHeaderFile) {
+ genCodeHeader(strHeaderFile);
+ gIndentLevel = 0;
+ setNodeIndex();
+ genCodeDefs();
+ genCodeFunc();
+ }
+
+ private void genCodeHeader(String strHeaderFile) {
+ String fmt = getFormatString(strHeaderFile);
+ String fileName = "mlr" + tnFunc.getFunctionId() + ".c";
+ int nTrees = tnFunc.getNumberOfTrees();
+ int nLeaves = tnFunc.getTree(0).getNumInternalNodes() + 1;
+ String header = MessageFormat.format(fmt, fileName, new Date(),
+ Integer.toString(nTrees),
+ Integer.toString(nLeaves));
+
+ gIndentLevel = 0;
+ printLn(0, header);
+ }
+
+ private void setNodeIndex() {
+
+ // set node id for each tree
+ int n = tnFunc.getNumberOfTrees();
+
+ SetNodeIndexVisitor nodeVisitor = new SetNodeIndexVisitor();
+ SetLeafIndexVisitor leafVisitor = new SetLeafIndexVisitor();
+ for (int i = 0; i < n; i++) {
+ nodeIdx = 0;
+ Tree tree = tnFunc.getTree(i);
+ traverseTree(tree.getRoot(), nodeVisitor);
+ traverseTree(tree.getRoot(), leafVisitor);
+ }
+ }
+
+ private void genCodeDefs() {
+ printLn(0, "#define TOTAL_TREES " + tnFunc.getNumberOfTrees());
+ printLn();
+
+ // const def for internal node labels
+ //genCodeTraverseTrees(0, null, new PrintNodeLabelDefVisitor(), null);
+
+ // const def for leaf node labels
+ //genCodeTraverseTrees(0, null, new PrintLeafLabelDefVisitor(), null);
+
+ genCodeNamespaceDefs();
+
+ // array init for internal nodes
+ genCodeTraverseTrees(1, "static const TreeNode nodes[] = {",
+ new PrintNodeInitVisitor(), "};");
+
+ // array init for leaf nodes
+ genCodeTraverseTrees(1, "static const double leaves[] = {",
+ new PrintLeafInitVisitor(), "};");
+
+ // array of tree size (number of internal nodes)
+ genCodeTreeSizeArrayInit(1);
+
+ genCodeEarlyExits(1);
+
+ }
+
+ private void genCodeNamespaceDefs() {
+ printLn(0, "namespace " + tnFunc.getNameSpace() + " {");
+ printLn();
+ printLn(0, "enum Feature {");
+ for (String f : tnFunc.getFeatureSet()) {
+ printLn(1, f + ",");
+ }
+ printLn(1, "NUMBER_FEATURES");
+ printLn(0, "}; /* enum */"); // end enum
+ printLn();
+ printLn(0, "} /* namespace */"); // end namespace
+ printLn();
+ }
+
+ private void genCodeFunc() {
+
+ // function definition
+ printLn(0, "double");
+ printLn(0, tnFunc.getFunctionName() + "(MlrScoreReq& msr) {");
+ printLn();
+
+ genFeatureArrayDecl(1);
+
+ // call traverseAll()
+ printLn(1, "msr.traverseAll(nodes, leaves, fValue, TOTAL_TREES, numNodes, meExits);");
+ printLn();
+
+ genCodeEpilog(1);
+
+ printLn(1, "return msr.getScore();");
+ printLn(0, "}"); // end function
+
+ fpCpp.close();
+ }
+
+ private void genFeatureArrayDecl(int indentInc) {
+ String ns = tnFunc.getNameSpace();
+
+ printLn(indentInc, "double fValue[" + ns + "::NUMBER_FEATURES];");
+ printLn();
+
+ // FNTM: Distinguished values
+ //printInd(1, "double FNTM = fValue[" + ns + "::FNTM] = msr.getFeature(rf::FNTM);");
+ //printLn();
+
+ HashSet<String> fSet = tnFunc.getFeatureSet();
+ //fSet.remove("FNTM");
+
+ // initialization of features
+ for (String f : fSet) {
+ printLn(indentInc, "fValue[" + ns + "::" + f + "] = msr.getFeature(rf::" + f + ");");
+ }
+
+ printLn();
+ }
+
+ /**
+ * Prints code by iterating over all trees and visiting each tree node with
+ * the TreeNodeVisitor.
+ *
+ * @param indentInc -
+ * indentation level of the first line
+ * @param start -
+ * code printed before iterations
+ * @param end -
+ * code printed after iterations
+ */
+ private void genCodeTraverseTrees(int indentInc, String start,
+ TreeNodeVisitor visitor, String end) {
+
+ if (start != null)
+ printLn(indentInc, start);
+
+ gIndentLevel += (indentInc + 1);
+ int n = tnFunc.getNumberOfTrees();
+ for (int i = 0; i < n; i++) {
+ Tree t = tnFunc.getTree(i);
+ printLn("// " + t.getId() + " " + t.getComment());
+
+ traverseTree(t.getRoot(), visitor);
+ printLn();
+ }
+ gIndentLevel -= (indentInc + 1);
+
+ if (end != null)
+ printLn(indentInc, end);
+
+ printLn();
+ }
+
+ private void genCodeTreeSizeArrayInit(int indentInc) {
+ String strDef = "static const int numNodes["
+ + tnFunc.getNumberOfTrees() + "] = {";
+ printLn(indentInc, strDef);
+
+ int n = tnFunc.getNumberOfTrees();
+ for (int i = 0; i < n; i++) {
+ String msg = tnFunc.getTree(i).getNumInternalNodes() + ", // " + i;
+ printLn(indentInc + 1, msg);
+ }
+
+ printLn(indentInc, "};");
+ printLn();
+ }
+
+ private void genCodeEarlyExits(int indentInc) {
+ printLn(indentInc, "static const MlrEarlyExit meExits[] = {");
+
+ int n = tnFunc.getNumEarlyExits();
+ for (int i = 0; i < n; i++) {
+ EarlyExit eex = tnFunc.getEarlyExit(i);
+ String strEarlyExit = "{" + eex.getTreeId() + ", "
+ + "decisiontree::OP_" + eex.getOperator().getId().toUpperCase() + ", "
+ + eex.getValue()
+ + "},";
+ printLn(indentInc + 1, strEarlyExit);
+ }
+
+ // always generate a sentinel element for terminal condition
+ String sentinel =
+ "{" + tnFunc.getNumberOfTrees()
+ + ", decisiontree::OP_NONE, 0.0}";
+ printLn(indentInc + 1, sentinel);
+ printLn(indentInc, "};");
+ printLn();
+ }
+
+ /**
+ * Currently only generate code for normalize()
+ */
+ private void genCodeEpilog(int indentInc) {
+ Epilog epilog = tnFunc.getEpilog();
+ if (epilog == null)
+ return;
+
+ Function func = epilog.getFunction();
+
+ if (func instanceof FuncNormalize) {
+ FuncNormalize funcNorm = (FuncNormalize) func;
+
+ if (funcNorm.getInvertMethod() != FuncNormalize.INV_NONE) {
+ genCodeInversion(indentInc, funcNorm);
+ }
+
+ if (funcNorm.isGenNormalize()) {
+ genCodeNormalize(indentInc, funcNorm);
+ }
+ } else if (func instanceof FuncPolytransform) {
+ FuncPolytransform funcPolytransform = (FuncPolytransform) func;
+ genCodePolytransform(indentInc, funcPolytransform);
+ } else {
+ throw new MlrCodeGenException("Unknown <epilogue> function: " + func.getClass().getName());
+ }
+ }
+
+ private void genCodeInversion(int indentInc, FuncNormalize funcNorm) {
+ if (funcNorm.getInvertMethod() == FuncNormalize.INV_INVERSION) {
+ printLn(indentInc, "msr.invert(" + funcNorm.getInvertedFrom() + ");");
+ printLn();
+ } else if (funcNorm.getInvertMethod() == FuncNormalize.INV_NEGATION) {
+ printLn(indentInc, "msr.negate();");
+ printLn();
+ }
+ }
+
+ private void genCodeNormalize(int indentInc, FuncNormalize funcNorm) {
+ StringBuilder sb = new StringBuilder();
+
+ printLn(indentInc, "msr.normalize(");
+
+ sb.append(funcNorm.getMean0()).append(", ")
+ .append(funcNorm.getSd0()).append(", ")
+ .append(funcNorm.getA0()).append(", ")
+ .append(funcNorm.getB0()).append(", ");
+ printLn(indentInc + 1, sb.toString());
+
+ sb.setLength(0);
+ sb.append(funcNorm.getMean1()).append(", ")
+ .append(funcNorm.getSd1()).append(", ")
+ .append(funcNorm.getA1()).append(", ")
+ .append(funcNorm.getB1()).append(", ");
+ printLn(indentInc + 1, sb.toString());
+
+ sb.setLength(0);
+ sb.append(funcNorm.getMean2()).append(", ")
+ .append(funcNorm.getSd2()).append(", ")
+ .append(funcNorm.getA2()).append(", ")
+ .append(funcNorm.getB2()).append(", ");
+ printLn(indentInc + 1, sb.toString());
+
+ sb.setLength(0);
+ sb.append(funcNorm.getMean3()).append(", ")
+ .append(funcNorm.getSd3()).append(", ")
+ .append(funcNorm.getA3()).append(", ")
+ .append(funcNorm.getB3());
+ printLn(indentInc + 1, sb.toString());
+
+ printLn(indentInc, ");");
+ printLn();
+
+ }
+
+ private void genCodePolytransform(int indentInc, FuncPolytransform funcPoly) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("msr.polytransform(");
+ sb.append(funcPoly.getA0()).append(", ")
+ .append(funcPoly.getA1()).append(", ")
+ .append(funcPoly.getA2()).append(", ")
+ .append(funcPoly.getA3()).append(");");
+ printLn(indentInc, sb.toString());
+ printLn();
+
+ }
+
+ // Utilities
+
+ private String getFormatString(String strFmtFile) {
+ try {
+ BufferedReader fp = new BufferedReader(
+ new FileReader(strFmtFile));
+ StringBuilder sb = new StringBuilder();
+ String line = null;
+ while ((line = fp.readLine()) != null) {
+ sb.append(line).append("\n");
+ }
+
+ String fmt = sb.toString();
+ fp.close();
+
+ return fmt;
+
+ } catch (FileNotFoundException e) {
+ throw new MlrCodeGenException(strFmtFile, e);
+ } catch (IOException ioe) {
+ throw new MlrCodeGenException("reading file " + strFmtFile,
+ ioe);
+ }
+ }
+
+ private void traverseTree(TreeNode node, TreeNodeVisitor v) {
+ v.visit(node);
+ if (node instanceof InternalNode) {
+ InternalNode dcNode = (InternalNode) node;
+ traverseTree(dcNode.getLeftNode(), v);
+ traverseTree(dcNode.getRightNode(), v);
+ }
+ }
+
+ private void printAppend(String str) {
+ fpCpp.print(str);
+ }
+
+ private void printLn(int inc, String str) {
+ int indent = gIndentLevel + inc;
+ for (int i = 0; i < indent; i++) {
+ fpCpp.print(INDENT_UNIT);
+ }
+ fpCpp.println(str);
+ }
+
+ private void printLn(String str) {
+ printLn(0, str);
+ }
+
+ private void printLn() {
+ fpCpp.println();
+ }
+
+ /**
+ * subclasses of TreeNodeVisitor
+ */
+
+ private class SetNodeIndexVisitor implements TreeNodeVisitor {
+
+ public void visit(TreeNode node) {
+ if (node instanceof InternalNode) {
+ node.setIndex(nodeIdx++);
+ }
+ }
+ }
+
+ private class SetLeafIndexVisitor implements TreeNodeVisitor {
+
+ public void visit(TreeNode node) {
+ if (node instanceof ResponseNode) {
+ node.setIndex(nodeIdx++);
+ }
+ }
+ }
+
+ /*
+ private class PrintNodeLabelDefVisitor implements TreeNodeVisitor {
+
+ public void visit(TreeNode node) {
+ if (node instanceof InternalNode) {
+ printInd(0, "#define " + node.label + " " + node.nodeId);
+ }
+ }
+ }
+
+ private class PrintLeafLabelDefVisitor implements TreeNodeVisitor {
+
+ public void visit(TreeNode node) {
+ if (node instanceof LeafNode) {
+ printInd(0, "#define " + node.label + " " + node.nodeId);
+ }
+ }
+ }
+ */
+
+ private class PrintNodeInitVisitor implements TreeNodeVisitor {
+
+ public void visit(TreeNode treeNode) {
+ if (treeNode instanceof InternalNode) {
+ InternalNode node = (InternalNode) treeNode;
+ int leftNodeIndex = node.getLeftNode().getIndex();
+ int rightNodeIndex = node.getRightNode().getIndex();
+
+ StringBuilder sb = new StringBuilder();
+ sb.append(node.getIndex() + " " + node.getId());
+ sb.append(" ").append(node.getComment());
+
+ String str = "{ " + tnFunc.getNameSpace() + "::" + node.getFeature() + ", "
+ + node.getValue() + ", " + leftNodeIndex + ", "
+ + rightNodeIndex + " }, // " + sb.toString();
+
+ printLn(str);
+ }
+ }
+ }
+
+ private class PrintLeafInitVisitor implements TreeNodeVisitor {
+
+ public void visit(TreeNode treeNode) {
+ if (treeNode instanceof ResponseNode) {
+ ResponseNode node = (ResponseNode) treeNode;
+ StringBuilder sb = new StringBuilder();
+ sb.append(node.getIndex() + " " + node.getId());
+ sb.append(" ").append(node.getComment());
+
+ String str = node.getResponse() + ", // " + sb.toString();
+ printLn(str);
+ }
+ }
+ }
+
+ public static void main(String[] args) {
+ String xmlFile = null;
+ String headerFile = null;
+ String cppFile = null;
+
+ int i = 0;
+ boolean hasErrors = false;
+ while (i < args.length && !hasErrors) {
+ String arg = args[i++];
+ if (arg.equals("-i")) {
+ if (i < args.length)
+ xmlFile = args[i++];
+ else
+ hasErrors = true;
+
+ } else if (arg.equals("-h")) {
+ if (i < args.length)
+ headerFile = args[i++];
+ else
+ hasErrors = true;
+
+ } else if (arg.equals("-o")) {
+ if (i < args.length)
+ cppFile = args[i++];
+ else
+ hasErrors = true;
+ }
+
+ }
+
+ if (xmlFile == null || headerFile == null)
+ hasErrors = true;
+
+ if (hasErrors) {
+ System.out.println("USAGE: java DecisionTreeXmlToCpp -i XML_file -h header_file [-o Cpp_file]");
+ return;
+ }
+
+ if (cppFile == null) {
+ if (xmlFile.endsWith(".xml")) {
+ int idx = xmlFile.lastIndexOf('.');
+ cppFile = xmlFile.substring(0, idx+1) + "c";
+ } else {
+ cppFile = xmlFile + ".c";
+ }
+ }
+
+ File fpCpp = new File(cppFile);
+ if (fpCpp.exists()) {
+ System.out.println(cppFile + " exits. Please rename and run again.");
+ return;
+ }
+
+ try {
+ MlrXmlParser parser = new MlrXmlParser();
+ DecisionTreeXmlToCpp toCpp = new DecisionTreeXmlToCpp(cppFile);
+
+ toCpp.setTnFunc((TreenetFunction) parser.parseXmlFile(xmlFile));
+ toCpp.genCode(headerFile);
+
+ } catch (DecisionTreeXmlException tnex) {
+ tnex.printStackTrace();
+ }
+
+ }
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java
new file mode 100644
index 00000000000..9a30f82f78b
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java
@@ -0,0 +1,18 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter;
+
+public class MlrCodeGenException extends RuntimeException {
+
+ public MlrCodeGenException() {
+ super();
+ }
+
+ public MlrCodeGenException(String msg) {
+ super(msg);
+ }
+
+ public MlrCodeGenException(String msg, Throwable cause) {
+ super(msg, cause);
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java
new file mode 100644
index 00000000000..022f4d31a4e
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java
@@ -0,0 +1,65 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter;
+
+import java.util.ArrayList;
+
+import org.w3c.dom.Element;
+import org.w3c.dom.Node;
+
+public class XmlUtils {
+
+ public static Element getFirstChildElement(Node parent) {
+ if (parent == null)
+ return null;
+
+ Node nd = parent.getFirstChild();
+ while (nd != null) {
+ //System.out.println("type: " + nd.getNodeType() + " name: " + nd.getNodeName());
+ if (nd.getNodeType() == Node.ELEMENT_NODE) {
+ return (Element)nd;
+ }
+ nd = nd.getNextSibling();
+ }
+
+ return null;
+ }
+
+ public static Element getFirstChildElementByName(Node parent, String childName) {
+ if (parent == null)
+ return null;
+
+ Node nd = parent.getFirstChild();
+ while (nd != null) {
+ //System.out.println("type: " + nd.getNodeType() + " name: " + nd.getNodeName());
+ if (nd.getNodeType() == Node.ELEMENT_NODE
+ && nd.getNodeName().equals(childName)) {
+ return (Element)nd;
+ }
+ nd = nd.getNextSibling();
+ }
+
+ return null;
+ }
+
+ public static ArrayList<Element> getChildrenByName(Node parent, String childName) {
+ if (parent == null)
+ return null;
+
+ ArrayList<Element> list = new ArrayList<Element>();
+ Node nd = parent.getFirstChild();
+ while (nd != null) {
+ //System.out.println("type: " + nd.getNodeType() + " name: " + nd.getNodeName());
+ if (nd.getNodeType() == Node.ELEMENT_NODE
+ && nd.getNodeName().equals(childName)) {
+ list.add((Element)nd);
+ }
+ nd = nd.getNextSibling();
+ }
+
+ if (list.size() == 0)
+ return null;
+ else
+ return list;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java
new file mode 100644
index 00000000000..decc5e73985
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java
@@ -0,0 +1,28 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+
+public class EarlyExit {
+ protected int treeId;
+ protected Operator operator;
+ protected String value;
+
+ public EarlyExit(int tid, Operator op, String val) {
+ treeId = tid;
+ operator = op;
+ value = val;
+ }
+
+ public int getTreeId() {
+ return treeId;
+ }
+
+ public String getValue() {
+ return value;
+ }
+
+ public Operator getOperator() {
+ return operator;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java
new file mode 100644
index 00000000000..a2ea7835869
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java
@@ -0,0 +1,16 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public class Epilog {
+
+ protected Function function;
+
+ public Function getFunction() {
+ return function;
+ }
+
+ public void setFunction(Function f) {
+ this.function = f;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java
new file mode 100644
index 00000000000..861c99dc3e3
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java
@@ -0,0 +1,211 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public class FuncNormalize implements Function {
+
+ public static int INV_NONE = 0;
+ public static int INV_INVERSION = 1;
+ public static int INV_NEGATION = 2;
+
+ private int invertMethod = INV_NONE;
+ private String invertedFrom;
+
+ /*
+ * The following parameters are type String to preserve precision.
+ */
+ private boolean pGenNormalize;
+ protected String mean0;
+ protected String mean1;
+ protected String mean2;
+ protected String mean3;
+ protected String sd0;
+ protected String sd1;
+ protected String sd2;
+ protected String sd3;
+ protected String a0;
+ protected String a1;
+ protected String a2;
+ protected String a3;
+ protected String b0;
+ protected String b1;
+ protected String b2;
+ protected String b3;
+
+ public FuncNormalize() {}
+
+ public int getInvertMethod() {
+ return invertMethod;
+ }
+
+ public void setInvertMethod(int inv) {
+ this.invertMethod = inv;
+ }
+
+ public String getInvertedFrom() {
+ return invertedFrom;
+ }
+
+ public void setInvertedFrom(String invertedFrom) {
+ this.invertedFrom = invertedFrom;
+ }
+
+ public boolean isGenNormalize() {
+ return pGenNormalize;
+ }
+
+ public void setpDoNormalize(boolean doNormalize) {
+ this.pGenNormalize = doNormalize;
+ }
+
+ public String getMean0() {
+ return mean0;
+ }
+
+ public void setMean0(String mean0) {
+ this.mean0 = mean0;
+ }
+
+ public String getMean1() {
+ return mean1;
+ }
+
+ public void setMean1(String mean1) {
+ this.mean1 = mean1;
+ }
+
+ public String getMean2() {
+ return mean2;
+ }
+
+ public void setMean2(String mean2) {
+ this.mean2 = mean2;
+ }
+
+ public String getMean3() {
+ return mean3;
+ }
+
+ public void setMean3(String mean3) {
+ this.mean3 = mean3;
+ }
+
+ public String getSd0() {
+ return sd0;
+ }
+
+ public void setSd0(String sd0) {
+ this.sd0 = sd0;
+ }
+
+ public String getSd1() {
+ return sd1;
+ }
+
+ public void setSd1(String sd1) {
+ this.sd1 = sd1;
+ }
+
+ public String getSd2() {
+ return sd2;
+ }
+
+ public void setSd2(String sd2) {
+ this.sd2 = sd2;
+ }
+
+ public String getSd3() {
+ return sd3;
+ }
+
+ public void setSd3(String sd3) {
+ this.sd3 = sd3;
+ }
+
+ public String getA0() {
+ return a0;
+ }
+
+ public void setA0(String a0) {
+ this.a0 = a0;
+ }
+
+ public String getA1() {
+ return a1;
+ }
+
+ public void setA1(String a1) {
+ this.a1 = a1;
+ }
+
+ public String getA2() {
+ return a2;
+ }
+
+ public void setA2(String a2) {
+ this.a2 = a2;
+ }
+
+ public String getA3() {
+ return a3;
+ }
+
+ public void setA3(String a3) {
+ this.a3 = a3;
+ }
+
+ public String getB0() {
+ return b0;
+ }
+
+ public void setB0(String b0) {
+ this.b0 = b0;
+ }
+
+ public String getB1() {
+ return b1;
+ }
+
+ public void setB1(String b1) {
+ this.b1 = b1;
+ }
+
+ public String getB2() {
+ return b2;
+ }
+
+ public void setB2(String b2) {
+ this.b2 = b2;
+ }
+
+ public String getB3() {
+ return b3;
+ }
+
+ public void setB3(String b3) {
+ this.b3 = b3;
+ }
+
+ public boolean validateParams() {
+ if (mean0 != null) {
+ if (mean1 == null || mean2 == null || mean3 == null ||
+ sd0 == null || sd1 == null || sd2 == null || sd3 == null ||
+ a0 == null || a1 == null || a2 == null || a3 == null ||
+ b0 == null || b1 == null || b2 == null || b3 == null) {
+ return false;
+ } else {
+ pGenNormalize = true;
+ }
+ } else { // mean0 == null
+ if (mean1 != null || mean2 != null || mean3 != null ||
+ sd0 != null || sd1 != null || sd2 != null || sd3 != null ||
+ a0 != null || a1 != null || a2 != null || a3 != null ||
+ b0 != null || b1 != null || b2 != null || b3 != null) {
+ return false;
+ } else {
+ pGenNormalize = false;
+ }
+ }
+
+ return true;
+ }
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java
new file mode 100644
index 00000000000..9925d60cd93
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java
@@ -0,0 +1,50 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public class FuncPolytransform implements Function {
+ /*
+ * The following parameters are type String to preserve precision.
+ */
+ protected String a0;
+ protected String a1;
+ protected String a2;
+ protected String a3;
+
+ public FuncPolytransform() {}
+
+ public String getA0() {
+ return a0;
+ }
+
+ public void setA0(String a0) {
+ this.a0 = a0;
+ }
+
+ public String getA1() {
+ return a1;
+ }
+
+ public void setA1(String a1) {
+ this.a1 = a1;
+ }
+
+ public String getA2() {
+ return a2;
+ }
+
+ public void setA2(String a2) {
+ this.a2 = a2;
+ }
+
+ public String getA3() {
+ return a3;
+ }
+
+ public void setA3(String a3) {
+ this.a3 = a3;
+ }
+
+ public boolean validateParams() {
+ return (a0 != null && a1 != null && a2 != null && a3 != null);
+ }
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java
new file mode 100644
index 00000000000..e2649652f52
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java
@@ -0,0 +1,8 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public interface Function {
+
+ public boolean validateParams();
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java
new file mode 100644
index 00000000000..a63361d1d20
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java
@@ -0,0 +1,61 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+
+public class InternalNode extends TreeNode {
+
+ private String feature;
+ private String op;
+ private String value;
+ private TreeNode left; // true
+ private TreeNode right; // false
+
+ public InternalNode(String i, String c, String f, String v, TreeNode lf, TreeNode rt) {
+ super(i, c);
+ feature = f;
+ value = v;
+ left = lf;
+ right = rt;
+ }
+
+ public String getFeature() {
+ return feature;
+ }
+
+ public void setFeature(String feature) {
+ this.feature = feature;
+ }
+
+ public String getOp() {
+ return op;
+ }
+
+ public void setOp(String op) {
+ this.op = op;
+ }
+
+ public String getValue() {
+ return value;
+ }
+
+ public void setValue(String value) {
+ this.value = value;
+ }
+
+ public TreeNode getLeftNode() {
+ return left;
+ }
+
+ public void setLeftNode(TreeNode left) {
+ this.left = left;
+ }
+
+ public TreeNode getRightNode() {
+ return right;
+ }
+
+ public void setRightNode(TreeNode right) {
+ this.right = right;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java
new file mode 100644
index 00000000000..760274dba0e
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java
@@ -0,0 +1,31 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+
+public abstract class MlrFunction {
+ protected String functionName;
+ protected String funcId; // numeric function id
+ protected String featureDefFile;
+ protected Epilog epilog;
+
+ public String getFunctionName() {
+ return functionName;
+ }
+
+ public String getFunctionId() {
+ return funcId;
+ }
+
+ public String getFeatureDefFile() {
+ return featureDefFile;
+ }
+
+ public Epilog getEpilog() {
+ return epilog;
+ }
+
+ public void setEpilog(Epilog epilog) {
+ this.epilog = epilog;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java
new file mode 100644
index 00000000000..9052ee8ecc3
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java
@@ -0,0 +1,36 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public enum Operator {
+ EQ("eq"),
+ NEQ("neq"),
+ GT("gt"),
+ GEQ("geq"),
+ LT("lt"),
+ LEQ("leq");
+
+ private final String id;
+
+ Operator(String id) {
+ this.id = id;
+ }
+
+ public static Operator parse(String str) {
+ for (Operator op : Operator.values()) {
+ if (op.id.equals(str))
+ return op;
+ }
+ throw new IllegalArgumentException();
+ }
+
+ public String getId() {
+ return id;
+ }
+
+ public static void main(String[] args) {
+ Operator op = Operator.parse("gt");
+ System.out.println("operator.toString = " + op.toString());
+ System.out.println("operator = " + op.getId());
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java
new file mode 100644
index 00000000000..b2d49023458
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java
@@ -0,0 +1,21 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public class ResponseNode extends TreeNode {
+
+ private double response;
+
+ public ResponseNode(String i, String c, double r) {
+ super(i, c);
+ response = r;
+ }
+
+ public double getResponse() {
+ return response;
+ }
+
+ public void setResponse(double response) {
+ this.response = response;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java
new file mode 100644
index 00000000000..ba9fb278cfc
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java
@@ -0,0 +1,57 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+
+public class Tree {
+
+ private String id;
+ private String comment;
+
+ private InternalNode root;
+ private int nInternalNodes; // number of internal nodes
+
+
+ public Tree() {}
+
+ public Tree(String id, String comment) {
+ this.id = id;
+ this.comment = comment;
+ }
+
+ public String getId() {
+ return id;
+ }
+
+ public void setId(String id) {
+ this.id = id;
+ }
+
+ public String getComment() {
+ return (comment == null ? "" : comment);
+ }
+
+ public void setComment(String comment) {
+ this.comment = comment;
+ }
+
+ public InternalNode getRoot() {
+ return root;
+ }
+
+ public void setRoot(InternalNode root) {
+ this.root = root;
+ }
+
+ public int getNumInternalNodes() {
+ return nInternalNodes;
+ }
+
+ public void incrInteralNodes() {
+ nInternalNodes++;
+ }
+
+ public void setNumInternalNodes(int n) {
+ nInternalNodes = n;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java
new file mode 100644
index 00000000000..158d4f8f788
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java
@@ -0,0 +1,39 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public class TreeNode {
+
+ private String id;
+ private String comment;
+ private int idx;
+
+ public TreeNode(String i, String c) {
+ id = i;
+ comment = c;
+ }
+
+ public String getId() {
+ return id;
+ }
+
+ public void setId(String id) {
+ this.id = id;
+ }
+
+ public String getComment() {
+ return (comment == null ? "" : comment);
+ }
+
+ public void setComment(String comment) {
+ this.comment = comment;
+ }
+
+ public int getIndex() {
+ return idx;
+ }
+
+ public void setIndex(int idx) {
+ this.idx = idx;
+ }
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java
new file mode 100644
index 00000000000..5ced96fb04c
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java
@@ -0,0 +1,8 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+public interface TreeNodeVisitor {
+
+ public void visit(TreeNode node);
+
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java
new file mode 100644
index 00000000000..2f9d4203d50
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java
@@ -0,0 +1,103 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.entity;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.w3c.dom.Attr;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
+import org.w3c.dom.NamedNodeMap;
+import org.w3c.dom.Node;
+import org.w3c.dom.NodeList;
+
+import com.yahoo.yst.libmlr.converter.parser.DecisionTreeXmlException;
+
+public class TreenetFunction extends MlrFunction {
+
+ private String ns; // namespace
+ private ArrayList<Tree> treeArylst;
+ private HashSet<String> featureSet;
+ private HashSet<String> labelSet;
+ protected ArrayList<EarlyExit> earlyExitArylst;
+
+
+ public TreenetFunction() {
+ treeArylst = new ArrayList<Tree>(500);
+ featureSet = new HashSet<String>();
+ labelSet = new HashSet<String>();
+ earlyExitArylst = new ArrayList<EarlyExit>(5);
+ }
+
+ public void setFunctionName(String id) {
+ funcId = id;
+ functionName = "mlr" + id;
+ /*
+ Pattern p = Pattern.compile("[^\\d]+(\\d+)\\w*");
+ Matcher m = p.matcher(functionName);
+ if (!m.matches())
+ throw new IllegalArgumentException("not a valid functionName");
+
+ funcId = m.group(1);
+ */
+ ns = "mlr" + funcId + "ns";
+ }
+
+ public String getNameSpace() {
+ return ns;
+ }
+
+ public int getNumberOfTrees() {
+ return treeArylst.size();
+ }
+
+ public Tree getTree(int i) {
+ return treeArylst.get(i);
+ }
+
+ public void setTree(Tree t) {
+ treeArylst.add(t);
+ }
+
+ public HashSet<String> getFeatureSet() {
+ return featureSet;
+ }
+
+ public HashSet<String> getLabelSet() {
+ return labelSet;
+ }
+
+ public void addFeature(String f) {
+ featureSet.add(f);
+ }
+
+ public void addLabel(String lbl) {
+ if (labelSet.contains(lbl))
+ throw new DecisionTreeXmlException("Label " + lbl + " existed.");
+ labelSet.add(lbl);
+ }
+
+ public void removeLabelSet() {
+ labelSet = null;
+ }
+
+ public void getAllFeatures() {
+ for (String f: featureSet) {
+ System.out.println(f);
+ }
+ }
+
+ public void addEarlyExit(EarlyExit earx) {
+ earlyExitArylst.add(earx);
+ }
+
+ public int getNumEarlyExits() {
+ return earlyExitArylst.size();
+ }
+
+ public EarlyExit getEarlyExit(int i) {
+ return earlyExitArylst.get(i);
+ }
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java
new file mode 100644
index 00000000000..3f792ef991f
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java
@@ -0,0 +1,17 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.parser;
+
+public class DecisionTreeXmlException extends RuntimeException {
+
+ public DecisionTreeXmlException() {
+ super();
+ }
+
+ public DecisionTreeXmlException(String msg) {
+ super(msg);
+ }
+
+ public DecisionTreeXmlException(String msg, Throwable cause) {
+ super(msg, cause);
+ }
+}
diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java
new file mode 100644
index 00000000000..a689aefcb98
--- /dev/null
+++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java
@@ -0,0 +1,435 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.yst.libmlr.converter.parser;
+
+import com.yahoo.yst.libmlr.converter.XmlUtils;
+import com.yahoo.yst.libmlr.converter.entity.*;
+import org.w3c.dom.*;
+
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
+import javax.xml.parsers.ParserConfigurationException;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.logging.Logger;
+
+/**
+ * Parses Treenet output V5 into Abstract Treenet XML File format.
+ *
+ * @author allenwei
+ *
+ */
+public class MlrXmlParser {
+
+ private static Logger logger = Logger.getLogger("com.yahoo.yst.libmlrutil.TnXmlParser");
+ private static final String errNormAttr = "<Normalize>: All or none of attributes mean0-3, sd0-3, a0-3, b0-3 are required";
+ private static final String errPolyAttr = "<Normalize>: All or none of attributes a0-3 are required";
+
+ private HashSet<String> treeIdSet = new HashSet<String>(500);
+ private HashSet<String> nodeIdSet = new HashSet<String>(10000);
+ private HashSet<String> respIdSet = new HashSet<String>(10000);
+
+ public MlrFunction parseXmlFile(String fileName) throws DecisionTreeXmlException {
+
+ File file = new File(fileName);
+ if (!file.exists()) {
+ String errMsg = fileName + " does not exist.";
+ logErrors(errMsg);
+ throw new DecisionTreeXmlException(errMsg);
+ }
+
+ DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
+ DocumentBuilder docBuilder = null;
+
+ try {
+ docBuilder = dbf.newDocumentBuilder();
+ Document doc = docBuilder.parse(file);
+ Element eltMlrFnc = doc.getDocumentElement();
+ if (!eltMlrFnc.getTagName().equals("MlrFunction")) {
+ String errMsg = "The top element must be <MlrFunction>";
+ logErrors(errMsg);
+ throw new DecisionTreeXmlException(errMsg);
+ }
+
+ return parseMlrFunction(eltMlrFnc);
+
+ //System.out.println("features: " + tnFunc.getFeatureSet().size());
+ //System.out.println("labels: " + tnFunc.getLabelSet().size());
+
+ } catch (ParserConfigurationException pe) {
+ String errMsg = "Cannot construct XML DocumentBuilder";
+ logErrors(pe, errMsg);
+ throw new DecisionTreeXmlException(errMsg, pe);
+
+ } catch (DecisionTreeXmlException te) {
+ throw te;
+
+ } catch (Exception ex) {
+ String errMsg ="Errors found parsing XML";
+ logErrors(errMsg);
+ ex.printStackTrace();
+ throw new DecisionTreeXmlException(errMsg, ex);
+ }
+ }
+
+ private MlrFunction parseMlrFunction(Element eltMlrFnc) {
+ MlrFunction mlrFunc = null;
+
+ Element eltDecisionTree = getFirstChildElementByName(eltMlrFnc, "DecisionTree", true);
+ TreenetFunction tnFunc = new TreenetFunction();
+ String id = getAttribute(eltMlrFnc, "name", true);
+ try {
+ Integer.parseInt(id);
+ } catch (NumberFormatException nex) {
+ throw new DecisionTreeXmlException("name in <MlrFunction> should be an integer " + id , nex);
+ }
+ tnFunc.setFunctionName(id);
+ parseDecisionTree(eltDecisionTree, tnFunc);
+
+ mlrFunc = tnFunc;
+
+ if (mlrFunc != null) {
+ Element eltEpilog = getFirstChildElementByName(eltMlrFnc, "Epilogue", false);
+ if (eltEpilog != null) {
+ Epilog epilog = parseEpilog(eltEpilog);
+ mlrFunc.setEpilog(epilog);
+ }
+ }
+
+ return mlrFunc;
+ }
+
+ private void parseDecisionTree(Element eltDecisionTree, TreenetFunction tnFunc) {
+ parseForest(getFirstChildElementByName(eltDecisionTree, "Forest", true), tnFunc);
+
+ Element eltEarlyExits = getFirstChildElementByName(eltDecisionTree, "EarlyExits", false);
+ if (eltEarlyExits != null)
+ parseEarlyExits(eltEarlyExits, tnFunc);
+ }
+
+ private void parseForest(Element eltForest, TreenetFunction tnFunc) {
+ //String strTotal = eltForest.getAttribute("total");
+ // tnFunc.setNumTrees(Integer.parseInt(eltForest.getAttribute("total")));
+
+ ArrayList<Element> nl = XmlUtils.getChildrenByName(eltForest, "Tree");
+ int n = nl.size();
+ if (n == 0)
+ throw new DecisionTreeXmlException("<Forest> should have at least one <Tree> element");
+
+ for (int i = 0; i < n; i++) {
+ parseTree(nl.get(i), tnFunc);
+ }
+ }
+
+ private void parseTree(Element eltTree, TreenetFunction tnFunc) {
+ String comment = getAttribute(eltTree, "comment", false);
+ String id = getAttribute(eltTree, "id", true);
+ if (treeIdSet.contains(id))
+ throw new DecisionTreeXmlException("Duplicate tree id " + id);
+ else
+ treeIdSet.add(id);
+
+ Tree tr = new Tree(id, comment);
+ tnFunc.setTree(tr);
+
+ // DEBUG
+ //System.out.println("tree " + id);
+ ArrayList<Element> list = XmlUtils.getChildrenByName(eltTree, "Node");
+ if (list == null || list.size() != 1)
+ throw new DecisionTreeXmlException("<Tree> should have exactly one root <Node> element");
+
+ Element eltNode = list.get(0);
+ InternalNode root = parseInternalNode(eltNode, tnFunc, tr);
+ tr.setRoot(root);
+ }
+
+ private TreeNode parseTreeNode(Element eltNode, TreenetFunction tnFunc, Tree tr) {
+ String tag = eltNode.getNodeName();
+ if (tag.equals("Node"))
+ return parseInternalNode(eltNode, tnFunc, tr);
+ else if (tag.equals("Response"))
+ return parseResponse(eltNode, tnFunc);
+ else
+ throw new DecisionTreeXmlException("ERROR: unknown tag <" + tag + ">. Should never reach here.");
+ }
+
+ private InternalNode parseInternalNode(Element eltNode, TreenetFunction tnFunc, Tree tr) {
+ tr.incrInteralNodes();
+
+ String id = getAttribute(eltNode, "id", true);
+ if (nodeIdSet.contains(id))
+ throw new DecisionTreeXmlException("Duplicate Internal Node id " + id);
+ else
+ nodeIdSet.add(id);
+
+ String comment = getAttribute(eltNode, "comment", false);
+
+ String feature = getAttribute(eltNode, "feature", true);
+ tnFunc.addFeature(feature);
+
+ String value = getAttribute(eltNode, "value", true);
+ try {
+ Double.parseDouble(value);
+ } catch (NumberFormatException nfex) {
+ String errMsg = "Node " + id + ": value not a number: " + value;
+ throw new DecisionTreeXmlException(errMsg, nfex);
+ }
+
+ ArrayList<Node> childNodes = new ArrayList<Node>(5);
+
+ NodeList nl = eltNode.getChildNodes();
+ int n = nl.getLength();
+ Node nd;
+ for (int i = 0; i < n; i++) {
+ nd = nl.item(i);
+ if (nd.getNodeType() == Node.ELEMENT_NODE) {
+ String tag = nd.getNodeName();
+ if (tag.equals("Node") || tag.equals("Response")) {
+ childNodes.add(nd);
+ }
+ }
+ }
+
+ int numChildNodes = childNodes.size();
+ if (numChildNodes != 2) {
+ String strNode = "Node: id=" + id + " " + feature + " " + value;
+ String errMsgNodes = "ERROR: A <Node> element should have exactly 2 child nodes. A child node can be <Node> or <Response>. " + strNode;
+ throw new DecisionTreeXmlException(errMsgNodes);
+ }
+
+ TreeNode left = parseTreeNode((Element)childNodes.get(0), tnFunc, tr);
+ TreeNode right = parseTreeNode((Element)childNodes.get(1), tnFunc, tr);
+
+ return new InternalNode(id, comment, feature, value, left, right);
+ }
+
+ private ResponseNode parseResponse(Element eltResponse, TreenetFunction tnFunc) {
+ String id = getAttribute(eltResponse, "id", true);
+ if (respIdSet.contains(id))
+ throw new DecisionTreeXmlException("Duplicate Response Node id " + id);
+ else
+ respIdSet.add(id);
+
+ String comment = getAttribute(eltResponse, "comment", false);
+
+ String strValue = eltResponse.getAttribute("value");
+ double value;
+ try {
+ value = Double.parseDouble(strValue);
+ } catch (NumberFormatException ne) {
+ throw new DecisionTreeXmlException("Response Node " + id + " does not contain a double value. value=" + strValue);
+ }
+
+ return new ResponseNode(id, comment, value);
+ }
+
+ private void parseEarlyExits(Element eltEarlyExits, TreenetFunction tnFunc) {
+ ArrayList<Element> nl = XmlUtils.getChildrenByName(eltEarlyExits, "Exit");
+ if (nl != null) {
+ int n = nl.size();
+ for (int i = 0; i < n; i++) {
+ parseExit(nl.get(i), tnFunc);
+ }
+ }
+ }
+
+ private void parseExit(Element eltExit, TreenetFunction tnFunc) {
+ String attr = getAttribute(eltExit, "tree", true);
+ int tree;
+ try {
+ tree = Integer.parseInt(attr);
+ } catch (NumberFormatException ex) {
+ String errMsg = "Invalid value for attribute tree: " + attr;
+ throw new DecisionTreeXmlException(errMsg);
+ }
+
+ String strValue = getAttribute(eltExit, "value", true);
+ try {
+ Double.parseDouble(attr);
+ } catch (NumberFormatException ex) {
+ String errMsg = "Invalid value for attribute value: " + attr;
+ throw new DecisionTreeXmlException(errMsg);
+ }
+
+ attr = getAttribute(eltExit, "op", true);
+ Operator op;
+ try {
+ op = Operator.parse(attr);
+ } catch (IllegalArgumentException ex) {
+ String errMsg = "Invalid value for attribute op: " + attr;
+ throw new DecisionTreeXmlException(errMsg);
+ }
+
+ tnFunc.addEarlyExit(new EarlyExit(tree, op, strValue));
+
+ }
+
+ private Epilog parseEpilog(Element eltEpilog) {
+ Element eltOp = XmlUtils.getFirstChildElement(eltEpilog);
+ if (eltOp.getNodeName().equals("Normalize")) {
+ try {
+ return parseNormalize(eltOp);
+ } catch (DecisionTreeXmlException e) {
+ return null;
+ }
+ } else if (eltOp.getNodeName().equals("Polytransform")) {
+ return parsePolytransform(eltOp);
+ }
+ else {
+ return null;
+ }
+ }
+
+ private Epilog parseNormalize(Element eltNorm) {
+ Epilog epilog = new Epilog();
+ FuncNormalize func = new FuncNormalize();
+ epilog.setFunction(func);
+
+ String strIsInv = getBoolAttribute(eltNorm, "isInverted", false);
+ if (strIsInv != null && strIsInv.equals("true")) {
+ String strInvFrom = getDoubleAttribute(eltNorm, "invertedFrom", true);
+ func.setInvertMethod(FuncNormalize.INV_INVERSION);
+ func.setInvertedFrom(strInvFrom);
+ }
+
+ String strIsNeg = getAttribute(eltNorm, "isNegated", false);
+ if (strIsNeg != null && strIsNeg.equals("true")) {
+ if (func.getInvertMethod() == FuncNormalize.INV_NONE)
+ func.setInvertMethod(FuncNormalize.INV_NEGATION);
+ else
+ throw new DecisionTreeXmlException("cannot have both isInverted and isNegated defined in element <Normalize>");
+ }
+
+ func.setMean0(getDoubleAttribute(eltNorm, "mean0", false));
+ func.setMean1(getDoubleAttribute(eltNorm, "mean1", false));
+ func.setMean2(getDoubleAttribute(eltNorm, "mean2", false));
+ func.setMean3(getDoubleAttribute(eltNorm, "mean3", false));
+ func.setSd0(getDoubleAttribute(eltNorm, "sd0", false));
+ func.setSd1(getDoubleAttribute(eltNorm, "sd1", false));
+ func.setSd2(getDoubleAttribute(eltNorm, "sd2", false));
+ func.setSd3(getDoubleAttribute(eltNorm, "sd3", false));
+ func.setA0(getDoubleAttribute(eltNorm, "a0", false));
+ func.setA1(getDoubleAttribute(eltNorm, "a1", false));
+ func.setA2(getDoubleAttribute(eltNorm, "a2", false));
+ func.setA3(getDoubleAttribute(eltNorm, "a3", false));
+ func.setB0(getDoubleAttribute(eltNorm, "b0", false));
+ func.setB1(getDoubleAttribute(eltNorm, "b1", false));
+ func.setB2(getDoubleAttribute(eltNorm, "b2", false));
+ func.setB3(getDoubleAttribute(eltNorm, "b3", false));
+
+ if (!func.validateParams())
+ throw new DecisionTreeXmlException(errNormAttr);
+
+ return epilog;
+ }
+
+ private Epilog parsePolytransform(Element eltOp) {
+ Epilog epilog = new Epilog();
+ FuncPolytransform func = new FuncPolytransform();
+
+ func.setA0(getDoubleAttribute(eltOp, "a0", false));
+ func.setA1(getDoubleAttribute(eltOp, "a1", false));
+ func.setA2(getDoubleAttribute(eltOp, "a2", false));
+ func.setA3(getDoubleAttribute(eltOp, "a3", false));
+
+ if (!func.validateParams())
+ throw new DecisionTreeXmlException(errPolyAttr);
+
+ epilog.setFunction(func);
+ return epilog;
+ }
+
+ /**
+ * Checks if the attribute name exists.
+ *
+ * @param eltNorm - the element containing the attribute
+ * @param attr - attribute name
+ * @return true if the attribute exists; or false, otherwise.
+ */
+ private boolean checkAttrExist(Element eltNorm, String attr) {
+ Attr attrNode = eltNorm.getAttributeNode(attr);
+ if (attrNode != null)
+ return true;
+ else
+ return false;
+ }
+
+ /**
+ * Returns the value of attribute.
+ *
+ * @param elt
+ * @param attr
+ * @param reqd
+ * @return If the attribute exists, the value of the attribute is returned, otherwise null is returned.
+ */
+ private String getAttribute(Element elt, String attr, boolean reqd) {
+ Attr attrNode = elt.getAttributeNode(attr);
+ String val = null;
+ if (attrNode != null)
+ val = elt.getAttribute(attr);
+
+ if (reqd && (val == null || val.equals("")))
+ throw new DecisionTreeXmlException(elt.getTagName() + ": missing required attribute " + attr);
+ return val;
+ }
+
+ private String getBoolAttribute(Element elt, String attr, boolean reqd) {
+ String strVal = getAttribute(elt, attr, reqd);
+
+ if (strVal == null ||
+ ((strVal.equals("true") || strVal.equals("false")))) {
+ return strVal;
+ } else {
+ String errMsg = "Attribute " + attr + " in Element " + elt.getTagName() + " is not a valid boolean value: " + strVal;
+ throw new DecisionTreeXmlException(errMsg);
+ }
+ }
+
+ private String getIntAttribute(Element elt, String attr, boolean reqd) {
+ String strVal = getAttribute(elt, attr, reqd);
+ try {
+ if (strVal != null)
+ Integer.parseInt(strVal);
+ return strVal;
+ } catch (NumberFormatException ne) {
+ String errMsg = "Attribute " + attr + " in Element " + elt.getTagName() + " is not a valid integer: " + strVal;
+ throw new DecisionTreeXmlException(errMsg);
+ }
+ }
+
+ private String getDoubleAttribute(Element elt, String attr, boolean reqd) {
+ String strVal = getAttribute(elt, attr, reqd);
+ try {
+ if (strVal != null)
+ Double.parseDouble(strVal);
+ return strVal;
+ } catch (NumberFormatException ne) {
+ String errMsg = "Attribute " + attr + " in Element " + elt.getTagName() + " is not a valid double: " + strVal;
+ throw new DecisionTreeXmlException(errMsg);
+ }
+ }
+
+ private Element getFirstChildElementByName(Element parent, String childName, boolean reqd) {
+ Element elt = XmlUtils.getFirstChildElementByName(parent, childName);
+ if (elt == null && reqd)
+ throw new DecisionTreeXmlException(elt.getTagName() + ": missing required element " + childName);
+ return elt;
+ }
+
+ private static void logErrors(String msg) {
+ logger.severe(msg);
+ System.out.println(msg);
+ }
+
+ private static void logErrors(Exception ex, String msg) {
+ String errMsg = ex.getClass().getName() + " " + ex.getMessage() + ": " + msg;
+ logger.severe(errMsg);
+ System.out.println(errMsg);
+ }
+
+ public static void main(String[] args) {
+ String fileName = "C:\\yst\\libMLR_framework\\mlr3135.xml";
+ new MlrXmlParser().parseXmlFile(fileName);
+ }
+
+}
diff --git a/libmlr/src/main/java/config/header_template.txt b/libmlr/src/main/java/config/header_template.txt
new file mode 100644
index 00000000000..16b6feaed2c
--- /dev/null
+++ b/libmlr/src/main/java/config/header_template.txt
@@ -0,0 +1,17 @@
+/**
+ * File: {0}
+ * Package: search/secore/libs/mlr
+ * Description:
+ *
+ * Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+ *
+ * Confidential
+ *
+ * Conversion Time: {1,date,long} {1,time,long}
+ *
+ * MODEL_SIZE: {2} trees x {3} leaf nodes
+ */
+#include "mlrfeatures.h"
+#include "mlrscorereq.h"
+#include "mlrfns.h"
+