diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-06-15 23:09:44 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-06-15 23:09:44 +0200 |
commit | 72231250ed81e10d66bfe70701e64fa5fe50f712 (patch) | |
tree | 2728bba1131a6f6e5bdf95afec7d7ff9358dac50 /libmlr |
Publish
Diffstat (limited to 'libmlr')
22 files changed, 1887 insertions, 0 deletions
diff --git a/libmlr/OWNERS b/libmlr/OWNERS new file mode 100644 index 00000000000..76e34e72c9d --- /dev/null +++ b/libmlr/OWNERS @@ -0,0 +1,2 @@ +lesters +tmartins diff --git a/libmlr/bin/xml2cpp b/libmlr/bin/xml2cpp new file mode 100755 index 00000000000..6367d745f78 --- /dev/null +++ b/libmlr/bin/xml2cpp @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +# BEGIN environment bootstrap section +# Do not edit between here and END as this section should stay identical in all scripts + +findpath () { + myname=${0} + mypath=${myname%/*} + myname=${myname##*/} + if [ "$mypath" ] && [ -d "$mypath" ]; then + return + fi + mypath=$(pwd) + if [ -f "${mypath}/${myname}" ]; then + return + fi + echo "FATAL: Could not figure out the path where $myname lives from $0" + exit 1 +} + +COMMON_ENV=libexec/vespa/common-env.sh + +source_common_env () { + if [ "$VESPA_HOME" ] && [ -d "$VESPA_HOME" ]; then + # ensure it ends with "/" : + VESPA_HOME=${VESPA_HOME%/}/ + export VESPA_HOME + common_env=$VESPA_HOME/$COMMON_ENV + if [ -f "$common_env" ]; then + . $common_env + return + fi + fi + return 1 +} + +findroot () { + source_common_env && return + if [ "$VESPA_HOME" ]; then + echo "FATAL: bad VESPA_HOME value '$VESPA_HOME'" + exit 1 + fi + if [ "$ROOT" ] && [ -d "$ROOT" ]; then + VESPA_HOME="$ROOT" + source_common_env && return + fi + findpath + while [ "$mypath" ]; do + VESPA_HOME=${mypath} + source_common_env && return + mypath=${mypath%/*} + done + echo "FATAL: missing VESPA_HOME environment variable" + echo "Could not locate $COMMON_ENV anywhere" + exit 1 +} + +findroot + +# END environment bootstrap section + +if [ $# -lt 2 ]; then + echo "USAGE $0 <template> <mlrxxx.xml>"; + exit; +fi + +JAVA=java +CLASSPATH=$VESPA_HOME/lib/jars/xml2cpp.jar + +echo ${JAVA} -Xms64m -Xmx256m -cp ${CLASSPATH} com.yahoo.yst.libmlr.converter.DecisionTreeXmlToCpp -h $1 -i $2 +${JAVA} -Xms64m -Xmx256m -cp ${CLASSPATH} com.yahoo.yst.libmlr.converter.DecisionTreeXmlToCpp -h $1 -i $2 diff --git a/libmlr/pom.xml b/libmlr/pom.xml new file mode 100644 index 00000000000..53ac04ffcc7 --- /dev/null +++ b/libmlr/pom.xml @@ -0,0 +1,53 @@ +<?xml version="1.0"?> +<!-- Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + <modelVersion>4.0.0</modelVersion> + <groupId>com.yahoo.vespa</groupId> + <artifactId>xml2cpp</artifactId> + <packaging>jar</packaging> + <version>1.0.0-SNAPSHOT</version> + <name>xml2cpp</name> + <description>Fork of xml2cppConverver with support for SS3 models.</description> + <dependencies> + </dependencies> + <build> + <finalName>${project.artifactId}</finalName> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-deploy-plugin</artifactId> + <version>2.5</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.1</version> + <configuration> + <compilerArgs> + <arg>-Xlint:all</arg> + <arg>-Xlint:-serial</arg> + <arg>-Werror</arg> + </compilerArgs> + <showWarnings>true</showWarnings> + <source>1.8</source> + <target>1.8</target> + <showDeprecation>true</showDeprecation> + <showWarnings>true</showWarnings> + <optimize>true</optimize> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-install-plugin</artifactId> + <version>2.3.1</version> + <configuration> + <updateReleaseInfo>true</updateReleaseInfo> + </configuration> + </plugin> + </plugins> + </build> + <properties> + <test.hide>true</test.hide> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + </properties> +</project> diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java new file mode 100644 index 00000000000..1dc6daeec41 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/DecisionTreeXmlToCpp.java @@ -0,0 +1,539 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.text.MessageFormat; +import java.util.Date; +import java.util.HashSet; + +import com.yahoo.yst.libmlr.converter.entity.EarlyExit; +import com.yahoo.yst.libmlr.converter.entity.Epilog; +import com.yahoo.yst.libmlr.converter.entity.FuncNormalize; +import com.yahoo.yst.libmlr.converter.entity.FuncPolytransform; +import com.yahoo.yst.libmlr.converter.entity.Function; +import com.yahoo.yst.libmlr.converter.entity.InternalNode; +import com.yahoo.yst.libmlr.converter.entity.ResponseNode; +import com.yahoo.yst.libmlr.converter.entity.Tree; +import com.yahoo.yst.libmlr.converter.entity.TreeNode; +import com.yahoo.yst.libmlr.converter.entity.TreeNodeVisitor; +import com.yahoo.yst.libmlr.converter.entity.TreenetFunction; +import com.yahoo.yst.libmlr.converter.parser.DecisionTreeXmlException; +import com.yahoo.yst.libmlr.converter.parser.MlrXmlParser; + +/** + * This class generates C++ from an MLR Decision Tree File + * + * @author allenwei + * + */ +public class DecisionTreeXmlToCpp { + + private static final String INDENT_UNIT = " "; + + private TreenetFunction tnFunc; + private String strCppFile; + private PrintWriter fpCpp; + + private int nodeIdx; + private int gIndentLevel; // global indent level + + public void setTnFunc(TreenetFunction tnFunc) { + this.tnFunc = tnFunc; + } + + public DecisionTreeXmlToCpp(String file) { + strCppFile = file; + try { + fpCpp = new PrintWriter( + new BufferedWriter( + new FileWriter(strCppFile))); + } catch (IOException ioex) { + System.out.println("Cannot open " + strCppFile + " for write"); + } + } + + /** + * Generates C++ code. + */ + public void genCode(String strHeaderFile) { + genCodeHeader(strHeaderFile); + gIndentLevel = 0; + setNodeIndex(); + genCodeDefs(); + genCodeFunc(); + } + + private void genCodeHeader(String strHeaderFile) { + String fmt = getFormatString(strHeaderFile); + String fileName = "mlr" + tnFunc.getFunctionId() + ".c"; + int nTrees = tnFunc.getNumberOfTrees(); + int nLeaves = tnFunc.getTree(0).getNumInternalNodes() + 1; + String header = MessageFormat.format(fmt, fileName, new Date(), + Integer.toString(nTrees), + Integer.toString(nLeaves)); + + gIndentLevel = 0; + printLn(0, header); + } + + private void setNodeIndex() { + + // set node id for each tree + int n = tnFunc.getNumberOfTrees(); + + SetNodeIndexVisitor nodeVisitor = new SetNodeIndexVisitor(); + SetLeafIndexVisitor leafVisitor = new SetLeafIndexVisitor(); + for (int i = 0; i < n; i++) { + nodeIdx = 0; + Tree tree = tnFunc.getTree(i); + traverseTree(tree.getRoot(), nodeVisitor); + traverseTree(tree.getRoot(), leafVisitor); + } + } + + private void genCodeDefs() { + printLn(0, "#define TOTAL_TREES " + tnFunc.getNumberOfTrees()); + printLn(); + + // const def for internal node labels + //genCodeTraverseTrees(0, null, new PrintNodeLabelDefVisitor(), null); + + // const def for leaf node labels + //genCodeTraverseTrees(0, null, new PrintLeafLabelDefVisitor(), null); + + genCodeNamespaceDefs(); + + // array init for internal nodes + genCodeTraverseTrees(1, "static const TreeNode nodes[] = {", + new PrintNodeInitVisitor(), "};"); + + // array init for leaf nodes + genCodeTraverseTrees(1, "static const double leaves[] = {", + new PrintLeafInitVisitor(), "};"); + + // array of tree size (number of internal nodes) + genCodeTreeSizeArrayInit(1); + + genCodeEarlyExits(1); + + } + + private void genCodeNamespaceDefs() { + printLn(0, "namespace " + tnFunc.getNameSpace() + " {"); + printLn(); + printLn(0, "enum Feature {"); + for (String f : tnFunc.getFeatureSet()) { + printLn(1, f + ","); + } + printLn(1, "NUMBER_FEATURES"); + printLn(0, "}; /* enum */"); // end enum + printLn(); + printLn(0, "} /* namespace */"); // end namespace + printLn(); + } + + private void genCodeFunc() { + + // function definition + printLn(0, "double"); + printLn(0, tnFunc.getFunctionName() + "(MlrScoreReq& msr) {"); + printLn(); + + genFeatureArrayDecl(1); + + // call traverseAll() + printLn(1, "msr.traverseAll(nodes, leaves, fValue, TOTAL_TREES, numNodes, meExits);"); + printLn(); + + genCodeEpilog(1); + + printLn(1, "return msr.getScore();"); + printLn(0, "}"); // end function + + fpCpp.close(); + } + + private void genFeatureArrayDecl(int indentInc) { + String ns = tnFunc.getNameSpace(); + + printLn(indentInc, "double fValue[" + ns + "::NUMBER_FEATURES];"); + printLn(); + + // FNTM: Distinguished values + //printInd(1, "double FNTM = fValue[" + ns + "::FNTM] = msr.getFeature(rf::FNTM);"); + //printLn(); + + HashSet<String> fSet = tnFunc.getFeatureSet(); + //fSet.remove("FNTM"); + + // initialization of features + for (String f : fSet) { + printLn(indentInc, "fValue[" + ns + "::" + f + "] = msr.getFeature(rf::" + f + ");"); + } + + printLn(); + } + + /** + * Prints code by iterating over all trees and visiting each tree node with + * the TreeNodeVisitor. + * + * @param indentInc - + * indentation level of the first line + * @param start - + * code printed before iterations + * @param end - + * code printed after iterations + */ + private void genCodeTraverseTrees(int indentInc, String start, + TreeNodeVisitor visitor, String end) { + + if (start != null) + printLn(indentInc, start); + + gIndentLevel += (indentInc + 1); + int n = tnFunc.getNumberOfTrees(); + for (int i = 0; i < n; i++) { + Tree t = tnFunc.getTree(i); + printLn("// " + t.getId() + " " + t.getComment()); + + traverseTree(t.getRoot(), visitor); + printLn(); + } + gIndentLevel -= (indentInc + 1); + + if (end != null) + printLn(indentInc, end); + + printLn(); + } + + private void genCodeTreeSizeArrayInit(int indentInc) { + String strDef = "static const int numNodes[" + + tnFunc.getNumberOfTrees() + "] = {"; + printLn(indentInc, strDef); + + int n = tnFunc.getNumberOfTrees(); + for (int i = 0; i < n; i++) { + String msg = tnFunc.getTree(i).getNumInternalNodes() + ", // " + i; + printLn(indentInc + 1, msg); + } + + printLn(indentInc, "};"); + printLn(); + } + + private void genCodeEarlyExits(int indentInc) { + printLn(indentInc, "static const MlrEarlyExit meExits[] = {"); + + int n = tnFunc.getNumEarlyExits(); + for (int i = 0; i < n; i++) { + EarlyExit eex = tnFunc.getEarlyExit(i); + String strEarlyExit = "{" + eex.getTreeId() + ", " + + "decisiontree::OP_" + eex.getOperator().getId().toUpperCase() + ", " + + eex.getValue() + + "},"; + printLn(indentInc + 1, strEarlyExit); + } + + // always generate a sentinel element for terminal condition + String sentinel = + "{" + tnFunc.getNumberOfTrees() + + ", decisiontree::OP_NONE, 0.0}"; + printLn(indentInc + 1, sentinel); + printLn(indentInc, "};"); + printLn(); + } + + /** + * Currently only generate code for normalize() + */ + private void genCodeEpilog(int indentInc) { + Epilog epilog = tnFunc.getEpilog(); + if (epilog == null) + return; + + Function func = epilog.getFunction(); + + if (func instanceof FuncNormalize) { + FuncNormalize funcNorm = (FuncNormalize) func; + + if (funcNorm.getInvertMethod() != FuncNormalize.INV_NONE) { + genCodeInversion(indentInc, funcNorm); + } + + if (funcNorm.isGenNormalize()) { + genCodeNormalize(indentInc, funcNorm); + } + } else if (func instanceof FuncPolytransform) { + FuncPolytransform funcPolytransform = (FuncPolytransform) func; + genCodePolytransform(indentInc, funcPolytransform); + } else { + throw new MlrCodeGenException("Unknown <epilogue> function: " + func.getClass().getName()); + } + } + + private void genCodeInversion(int indentInc, FuncNormalize funcNorm) { + if (funcNorm.getInvertMethod() == FuncNormalize.INV_INVERSION) { + printLn(indentInc, "msr.invert(" + funcNorm.getInvertedFrom() + ");"); + printLn(); + } else if (funcNorm.getInvertMethod() == FuncNormalize.INV_NEGATION) { + printLn(indentInc, "msr.negate();"); + printLn(); + } + } + + private void genCodeNormalize(int indentInc, FuncNormalize funcNorm) { + StringBuilder sb = new StringBuilder(); + + printLn(indentInc, "msr.normalize("); + + sb.append(funcNorm.getMean0()).append(", ") + .append(funcNorm.getSd0()).append(", ") + .append(funcNorm.getA0()).append(", ") + .append(funcNorm.getB0()).append(", "); + printLn(indentInc + 1, sb.toString()); + + sb.setLength(0); + sb.append(funcNorm.getMean1()).append(", ") + .append(funcNorm.getSd1()).append(", ") + .append(funcNorm.getA1()).append(", ") + .append(funcNorm.getB1()).append(", "); + printLn(indentInc + 1, sb.toString()); + + sb.setLength(0); + sb.append(funcNorm.getMean2()).append(", ") + .append(funcNorm.getSd2()).append(", ") + .append(funcNorm.getA2()).append(", ") + .append(funcNorm.getB2()).append(", "); + printLn(indentInc + 1, sb.toString()); + + sb.setLength(0); + sb.append(funcNorm.getMean3()).append(", ") + .append(funcNorm.getSd3()).append(", ") + .append(funcNorm.getA3()).append(", ") + .append(funcNorm.getB3()); + printLn(indentInc + 1, sb.toString()); + + printLn(indentInc, ");"); + printLn(); + + } + + private void genCodePolytransform(int indentInc, FuncPolytransform funcPoly) { + StringBuilder sb = new StringBuilder(); + sb.append("msr.polytransform("); + sb.append(funcPoly.getA0()).append(", ") + .append(funcPoly.getA1()).append(", ") + .append(funcPoly.getA2()).append(", ") + .append(funcPoly.getA3()).append(");"); + printLn(indentInc, sb.toString()); + printLn(); + + } + + // Utilities + + private String getFormatString(String strFmtFile) { + try { + BufferedReader fp = new BufferedReader( + new FileReader(strFmtFile)); + StringBuilder sb = new StringBuilder(); + String line = null; + while ((line = fp.readLine()) != null) { + sb.append(line).append("\n"); + } + + String fmt = sb.toString(); + fp.close(); + + return fmt; + + } catch (FileNotFoundException e) { + throw new MlrCodeGenException(strFmtFile, e); + } catch (IOException ioe) { + throw new MlrCodeGenException("reading file " + strFmtFile, + ioe); + } + } + + private void traverseTree(TreeNode node, TreeNodeVisitor v) { + v.visit(node); + if (node instanceof InternalNode) { + InternalNode dcNode = (InternalNode) node; + traverseTree(dcNode.getLeftNode(), v); + traverseTree(dcNode.getRightNode(), v); + } + } + + private void printAppend(String str) { + fpCpp.print(str); + } + + private void printLn(int inc, String str) { + int indent = gIndentLevel + inc; + for (int i = 0; i < indent; i++) { + fpCpp.print(INDENT_UNIT); + } + fpCpp.println(str); + } + + private void printLn(String str) { + printLn(0, str); + } + + private void printLn() { + fpCpp.println(); + } + + /** + * subclasses of TreeNodeVisitor + */ + + private class SetNodeIndexVisitor implements TreeNodeVisitor { + + public void visit(TreeNode node) { + if (node instanceof InternalNode) { + node.setIndex(nodeIdx++); + } + } + } + + private class SetLeafIndexVisitor implements TreeNodeVisitor { + + public void visit(TreeNode node) { + if (node instanceof ResponseNode) { + node.setIndex(nodeIdx++); + } + } + } + + /* + private class PrintNodeLabelDefVisitor implements TreeNodeVisitor { + + public void visit(TreeNode node) { + if (node instanceof InternalNode) { + printInd(0, "#define " + node.label + " " + node.nodeId); + } + } + } + + private class PrintLeafLabelDefVisitor implements TreeNodeVisitor { + + public void visit(TreeNode node) { + if (node instanceof LeafNode) { + printInd(0, "#define " + node.label + " " + node.nodeId); + } + } + } + */ + + private class PrintNodeInitVisitor implements TreeNodeVisitor { + + public void visit(TreeNode treeNode) { + if (treeNode instanceof InternalNode) { + InternalNode node = (InternalNode) treeNode; + int leftNodeIndex = node.getLeftNode().getIndex(); + int rightNodeIndex = node.getRightNode().getIndex(); + + StringBuilder sb = new StringBuilder(); + sb.append(node.getIndex() + " " + node.getId()); + sb.append(" ").append(node.getComment()); + + String str = "{ " + tnFunc.getNameSpace() + "::" + node.getFeature() + ", " + + node.getValue() + ", " + leftNodeIndex + ", " + + rightNodeIndex + " }, // " + sb.toString(); + + printLn(str); + } + } + } + + private class PrintLeafInitVisitor implements TreeNodeVisitor { + + public void visit(TreeNode treeNode) { + if (treeNode instanceof ResponseNode) { + ResponseNode node = (ResponseNode) treeNode; + StringBuilder sb = new StringBuilder(); + sb.append(node.getIndex() + " " + node.getId()); + sb.append(" ").append(node.getComment()); + + String str = node.getResponse() + ", // " + sb.toString(); + printLn(str); + } + } + } + + public static void main(String[] args) { + String xmlFile = null; + String headerFile = null; + String cppFile = null; + + int i = 0; + boolean hasErrors = false; + while (i < args.length && !hasErrors) { + String arg = args[i++]; + if (arg.equals("-i")) { + if (i < args.length) + xmlFile = args[i++]; + else + hasErrors = true; + + } else if (arg.equals("-h")) { + if (i < args.length) + headerFile = args[i++]; + else + hasErrors = true; + + } else if (arg.equals("-o")) { + if (i < args.length) + cppFile = args[i++]; + else + hasErrors = true; + } + + } + + if (xmlFile == null || headerFile == null) + hasErrors = true; + + if (hasErrors) { + System.out.println("USAGE: java DecisionTreeXmlToCpp -i XML_file -h header_file [-o Cpp_file]"); + return; + } + + if (cppFile == null) { + if (xmlFile.endsWith(".xml")) { + int idx = xmlFile.lastIndexOf('.'); + cppFile = xmlFile.substring(0, idx+1) + "c"; + } else { + cppFile = xmlFile + ".c"; + } + } + + File fpCpp = new File(cppFile); + if (fpCpp.exists()) { + System.out.println(cppFile + " exits. Please rename and run again."); + return; + } + + try { + MlrXmlParser parser = new MlrXmlParser(); + DecisionTreeXmlToCpp toCpp = new DecisionTreeXmlToCpp(cppFile); + + toCpp.setTnFunc((TreenetFunction) parser.parseXmlFile(xmlFile)); + toCpp.genCode(headerFile); + + } catch (DecisionTreeXmlException tnex) { + tnex.printStackTrace(); + } + + } +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java new file mode 100644 index 00000000000..9a30f82f78b --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/MlrCodeGenException.java @@ -0,0 +1,18 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter; + +public class MlrCodeGenException extends RuntimeException { + + public MlrCodeGenException() { + super(); + } + + public MlrCodeGenException(String msg) { + super(msg); + } + + public MlrCodeGenException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java new file mode 100644 index 00000000000..022f4d31a4e --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/XmlUtils.java @@ -0,0 +1,65 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter; + +import java.util.ArrayList; + +import org.w3c.dom.Element; +import org.w3c.dom.Node; + +public class XmlUtils { + + public static Element getFirstChildElement(Node parent) { + if (parent == null) + return null; + + Node nd = parent.getFirstChild(); + while (nd != null) { + //System.out.println("type: " + nd.getNodeType() + " name: " + nd.getNodeName()); + if (nd.getNodeType() == Node.ELEMENT_NODE) { + return (Element)nd; + } + nd = nd.getNextSibling(); + } + + return null; + } + + public static Element getFirstChildElementByName(Node parent, String childName) { + if (parent == null) + return null; + + Node nd = parent.getFirstChild(); + while (nd != null) { + //System.out.println("type: " + nd.getNodeType() + " name: " + nd.getNodeName()); + if (nd.getNodeType() == Node.ELEMENT_NODE + && nd.getNodeName().equals(childName)) { + return (Element)nd; + } + nd = nd.getNextSibling(); + } + + return null; + } + + public static ArrayList<Element> getChildrenByName(Node parent, String childName) { + if (parent == null) + return null; + + ArrayList<Element> list = new ArrayList<Element>(); + Node nd = parent.getFirstChild(); + while (nd != null) { + //System.out.println("type: " + nd.getNodeType() + " name: " + nd.getNodeName()); + if (nd.getNodeType() == Node.ELEMENT_NODE + && nd.getNodeName().equals(childName)) { + list.add((Element)nd); + } + nd = nd.getNextSibling(); + } + + if (list.size() == 0) + return null; + else + return list; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java new file mode 100644 index 00000000000..decc5e73985 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/EarlyExit.java @@ -0,0 +1,28 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + + +public class EarlyExit { + protected int treeId; + protected Operator operator; + protected String value; + + public EarlyExit(int tid, Operator op, String val) { + treeId = tid; + operator = op; + value = val; + } + + public int getTreeId() { + return treeId; + } + + public String getValue() { + return value; + } + + public Operator getOperator() { + return operator; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java new file mode 100644 index 00000000000..a2ea7835869 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Epilog.java @@ -0,0 +1,16 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public class Epilog { + + protected Function function; + + public Function getFunction() { + return function; + } + + public void setFunction(Function f) { + this.function = f; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java new file mode 100644 index 00000000000..861c99dc3e3 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncNormalize.java @@ -0,0 +1,211 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public class FuncNormalize implements Function { + + public static int INV_NONE = 0; + public static int INV_INVERSION = 1; + public static int INV_NEGATION = 2; + + private int invertMethod = INV_NONE; + private String invertedFrom; + + /* + * The following parameters are type String to preserve precision. + */ + private boolean pGenNormalize; + protected String mean0; + protected String mean1; + protected String mean2; + protected String mean3; + protected String sd0; + protected String sd1; + protected String sd2; + protected String sd3; + protected String a0; + protected String a1; + protected String a2; + protected String a3; + protected String b0; + protected String b1; + protected String b2; + protected String b3; + + public FuncNormalize() {} + + public int getInvertMethod() { + return invertMethod; + } + + public void setInvertMethod(int inv) { + this.invertMethod = inv; + } + + public String getInvertedFrom() { + return invertedFrom; + } + + public void setInvertedFrom(String invertedFrom) { + this.invertedFrom = invertedFrom; + } + + public boolean isGenNormalize() { + return pGenNormalize; + } + + public void setpDoNormalize(boolean doNormalize) { + this.pGenNormalize = doNormalize; + } + + public String getMean0() { + return mean0; + } + + public void setMean0(String mean0) { + this.mean0 = mean0; + } + + public String getMean1() { + return mean1; + } + + public void setMean1(String mean1) { + this.mean1 = mean1; + } + + public String getMean2() { + return mean2; + } + + public void setMean2(String mean2) { + this.mean2 = mean2; + } + + public String getMean3() { + return mean3; + } + + public void setMean3(String mean3) { + this.mean3 = mean3; + } + + public String getSd0() { + return sd0; + } + + public void setSd0(String sd0) { + this.sd0 = sd0; + } + + public String getSd1() { + return sd1; + } + + public void setSd1(String sd1) { + this.sd1 = sd1; + } + + public String getSd2() { + return sd2; + } + + public void setSd2(String sd2) { + this.sd2 = sd2; + } + + public String getSd3() { + return sd3; + } + + public void setSd3(String sd3) { + this.sd3 = sd3; + } + + public String getA0() { + return a0; + } + + public void setA0(String a0) { + this.a0 = a0; + } + + public String getA1() { + return a1; + } + + public void setA1(String a1) { + this.a1 = a1; + } + + public String getA2() { + return a2; + } + + public void setA2(String a2) { + this.a2 = a2; + } + + public String getA3() { + return a3; + } + + public void setA3(String a3) { + this.a3 = a3; + } + + public String getB0() { + return b0; + } + + public void setB0(String b0) { + this.b0 = b0; + } + + public String getB1() { + return b1; + } + + public void setB1(String b1) { + this.b1 = b1; + } + + public String getB2() { + return b2; + } + + public void setB2(String b2) { + this.b2 = b2; + } + + public String getB3() { + return b3; + } + + public void setB3(String b3) { + this.b3 = b3; + } + + public boolean validateParams() { + if (mean0 != null) { + if (mean1 == null || mean2 == null || mean3 == null || + sd0 == null || sd1 == null || sd2 == null || sd3 == null || + a0 == null || a1 == null || a2 == null || a3 == null || + b0 == null || b1 == null || b2 == null || b3 == null) { + return false; + } else { + pGenNormalize = true; + } + } else { // mean0 == null + if (mean1 != null || mean2 != null || mean3 != null || + sd0 != null || sd1 != null || sd2 != null || sd3 != null || + a0 != null || a1 != null || a2 != null || a3 != null || + b0 != null || b1 != null || b2 != null || b3 != null) { + return false; + } else { + pGenNormalize = false; + } + } + + return true; + } +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java new file mode 100644 index 00000000000..9925d60cd93 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/FuncPolytransform.java @@ -0,0 +1,50 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public class FuncPolytransform implements Function { + /* + * The following parameters are type String to preserve precision. + */ + protected String a0; + protected String a1; + protected String a2; + protected String a3; + + public FuncPolytransform() {} + + public String getA0() { + return a0; + } + + public void setA0(String a0) { + this.a0 = a0; + } + + public String getA1() { + return a1; + } + + public void setA1(String a1) { + this.a1 = a1; + } + + public String getA2() { + return a2; + } + + public void setA2(String a2) { + this.a2 = a2; + } + + public String getA3() { + return a3; + } + + public void setA3(String a3) { + this.a3 = a3; + } + + public boolean validateParams() { + return (a0 != null && a1 != null && a2 != null && a3 != null); + } +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java new file mode 100644 index 00000000000..e2649652f52 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Function.java @@ -0,0 +1,8 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public interface Function { + + public boolean validateParams(); + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java new file mode 100644 index 00000000000..a63361d1d20 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/InternalNode.java @@ -0,0 +1,61 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + + +public class InternalNode extends TreeNode { + + private String feature; + private String op; + private String value; + private TreeNode left; // true + private TreeNode right; // false + + public InternalNode(String i, String c, String f, String v, TreeNode lf, TreeNode rt) { + super(i, c); + feature = f; + value = v; + left = lf; + right = rt; + } + + public String getFeature() { + return feature; + } + + public void setFeature(String feature) { + this.feature = feature; + } + + public String getOp() { + return op; + } + + public void setOp(String op) { + this.op = op; + } + + public String getValue() { + return value; + } + + public void setValue(String value) { + this.value = value; + } + + public TreeNode getLeftNode() { + return left; + } + + public void setLeftNode(TreeNode left) { + this.left = left; + } + + public TreeNode getRightNode() { + return right; + } + + public void setRightNode(TreeNode right) { + this.right = right; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java new file mode 100644 index 00000000000..760274dba0e --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/MlrFunction.java @@ -0,0 +1,31 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + + +public abstract class MlrFunction { + protected String functionName; + protected String funcId; // numeric function id + protected String featureDefFile; + protected Epilog epilog; + + public String getFunctionName() { + return functionName; + } + + public String getFunctionId() { + return funcId; + } + + public String getFeatureDefFile() { + return featureDefFile; + } + + public Epilog getEpilog() { + return epilog; + } + + public void setEpilog(Epilog epilog) { + this.epilog = epilog; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java new file mode 100644 index 00000000000..9052ee8ecc3 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Operator.java @@ -0,0 +1,36 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public enum Operator { + EQ("eq"), + NEQ("neq"), + GT("gt"), + GEQ("geq"), + LT("lt"), + LEQ("leq"); + + private final String id; + + Operator(String id) { + this.id = id; + } + + public static Operator parse(String str) { + for (Operator op : Operator.values()) { + if (op.id.equals(str)) + return op; + } + throw new IllegalArgumentException(); + } + + public String getId() { + return id; + } + + public static void main(String[] args) { + Operator op = Operator.parse("gt"); + System.out.println("operator.toString = " + op.toString()); + System.out.println("operator = " + op.getId()); + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java new file mode 100644 index 00000000000..b2d49023458 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/ResponseNode.java @@ -0,0 +1,21 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public class ResponseNode extends TreeNode { + + private double response; + + public ResponseNode(String i, String c, double r) { + super(i, c); + response = r; + } + + public double getResponse() { + return response; + } + + public void setResponse(double response) { + this.response = response; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java new file mode 100644 index 00000000000..ba9fb278cfc --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/Tree.java @@ -0,0 +1,57 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + + +public class Tree { + + private String id; + private String comment; + + private InternalNode root; + private int nInternalNodes; // number of internal nodes + + + public Tree() {} + + public Tree(String id, String comment) { + this.id = id; + this.comment = comment; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getComment() { + return (comment == null ? "" : comment); + } + + public void setComment(String comment) { + this.comment = comment; + } + + public InternalNode getRoot() { + return root; + } + + public void setRoot(InternalNode root) { + this.root = root; + } + + public int getNumInternalNodes() { + return nInternalNodes; + } + + public void incrInteralNodes() { + nInternalNodes++; + } + + public void setNumInternalNodes(int n) { + nInternalNodes = n; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java new file mode 100644 index 00000000000..158d4f8f788 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNode.java @@ -0,0 +1,39 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public class TreeNode { + + private String id; + private String comment; + private int idx; + + public TreeNode(String i, String c) { + id = i; + comment = c; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getComment() { + return (comment == null ? "" : comment); + } + + public void setComment(String comment) { + this.comment = comment; + } + + public int getIndex() { + return idx; + } + + public void setIndex(int idx) { + this.idx = idx; + } + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java new file mode 100644 index 00000000000..5ced96fb04c --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreeNodeVisitor.java @@ -0,0 +1,8 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +public interface TreeNodeVisitor { + + public void visit(TreeNode node); + +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java new file mode 100644 index 00000000000..2f9d4203d50 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/entity/TreenetFunction.java @@ -0,0 +1,103 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.entity; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.w3c.dom.Attr; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.w3c.dom.NamedNodeMap; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; + +import com.yahoo.yst.libmlr.converter.parser.DecisionTreeXmlException; + +public class TreenetFunction extends MlrFunction { + + private String ns; // namespace + private ArrayList<Tree> treeArylst; + private HashSet<String> featureSet; + private HashSet<String> labelSet; + protected ArrayList<EarlyExit> earlyExitArylst; + + + public TreenetFunction() { + treeArylst = new ArrayList<Tree>(500); + featureSet = new HashSet<String>(); + labelSet = new HashSet<String>(); + earlyExitArylst = new ArrayList<EarlyExit>(5); + } + + public void setFunctionName(String id) { + funcId = id; + functionName = "mlr" + id; + /* + Pattern p = Pattern.compile("[^\\d]+(\\d+)\\w*"); + Matcher m = p.matcher(functionName); + if (!m.matches()) + throw new IllegalArgumentException("not a valid functionName"); + + funcId = m.group(1); + */ + ns = "mlr" + funcId + "ns"; + } + + public String getNameSpace() { + return ns; + } + + public int getNumberOfTrees() { + return treeArylst.size(); + } + + public Tree getTree(int i) { + return treeArylst.get(i); + } + + public void setTree(Tree t) { + treeArylst.add(t); + } + + public HashSet<String> getFeatureSet() { + return featureSet; + } + + public HashSet<String> getLabelSet() { + return labelSet; + } + + public void addFeature(String f) { + featureSet.add(f); + } + + public void addLabel(String lbl) { + if (labelSet.contains(lbl)) + throw new DecisionTreeXmlException("Label " + lbl + " existed."); + labelSet.add(lbl); + } + + public void removeLabelSet() { + labelSet = null; + } + + public void getAllFeatures() { + for (String f: featureSet) { + System.out.println(f); + } + } + + public void addEarlyExit(EarlyExit earx) { + earlyExitArylst.add(earx); + } + + public int getNumEarlyExits() { + return earlyExitArylst.size(); + } + + public EarlyExit getEarlyExit(int i) { + return earlyExitArylst.get(i); + } +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java new file mode 100644 index 00000000000..3f792ef991f --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/DecisionTreeXmlException.java @@ -0,0 +1,17 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.parser; + +public class DecisionTreeXmlException extends RuntimeException { + + public DecisionTreeXmlException() { + super(); + } + + public DecisionTreeXmlException(String msg) { + super(msg); + } + + public DecisionTreeXmlException(String msg, Throwable cause) { + super(msg, cause); + } +} diff --git a/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java new file mode 100644 index 00000000000..a689aefcb98 --- /dev/null +++ b/libmlr/src/main/java/com/yahoo/yst/libmlr/converter/parser/MlrXmlParser.java @@ -0,0 +1,435 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.yst.libmlr.converter.parser; + +import com.yahoo.yst.libmlr.converter.XmlUtils; +import com.yahoo.yst.libmlr.converter.entity.*; +import org.w3c.dom.*; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; +import java.io.File; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.logging.Logger; + +/** + * Parses Treenet output V5 into Abstract Treenet XML File format. + * + * @author allenwei + * + */ +public class MlrXmlParser { + + private static Logger logger = Logger.getLogger("com.yahoo.yst.libmlrutil.TnXmlParser"); + private static final String errNormAttr = "<Normalize>: All or none of attributes mean0-3, sd0-3, a0-3, b0-3 are required"; + private static final String errPolyAttr = "<Normalize>: All or none of attributes a0-3 are required"; + + private HashSet<String> treeIdSet = new HashSet<String>(500); + private HashSet<String> nodeIdSet = new HashSet<String>(10000); + private HashSet<String> respIdSet = new HashSet<String>(10000); + + public MlrFunction parseXmlFile(String fileName) throws DecisionTreeXmlException { + + File file = new File(fileName); + if (!file.exists()) { + String errMsg = fileName + " does not exist."; + logErrors(errMsg); + throw new DecisionTreeXmlException(errMsg); + } + + DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + DocumentBuilder docBuilder = null; + + try { + docBuilder = dbf.newDocumentBuilder(); + Document doc = docBuilder.parse(file); + Element eltMlrFnc = doc.getDocumentElement(); + if (!eltMlrFnc.getTagName().equals("MlrFunction")) { + String errMsg = "The top element must be <MlrFunction>"; + logErrors(errMsg); + throw new DecisionTreeXmlException(errMsg); + } + + return parseMlrFunction(eltMlrFnc); + + //System.out.println("features: " + tnFunc.getFeatureSet().size()); + //System.out.println("labels: " + tnFunc.getLabelSet().size()); + + } catch (ParserConfigurationException pe) { + String errMsg = "Cannot construct XML DocumentBuilder"; + logErrors(pe, errMsg); + throw new DecisionTreeXmlException(errMsg, pe); + + } catch (DecisionTreeXmlException te) { + throw te; + + } catch (Exception ex) { + String errMsg ="Errors found parsing XML"; + logErrors(errMsg); + ex.printStackTrace(); + throw new DecisionTreeXmlException(errMsg, ex); + } + } + + private MlrFunction parseMlrFunction(Element eltMlrFnc) { + MlrFunction mlrFunc = null; + + Element eltDecisionTree = getFirstChildElementByName(eltMlrFnc, "DecisionTree", true); + TreenetFunction tnFunc = new TreenetFunction(); + String id = getAttribute(eltMlrFnc, "name", true); + try { + Integer.parseInt(id); + } catch (NumberFormatException nex) { + throw new DecisionTreeXmlException("name in <MlrFunction> should be an integer " + id , nex); + } + tnFunc.setFunctionName(id); + parseDecisionTree(eltDecisionTree, tnFunc); + + mlrFunc = tnFunc; + + if (mlrFunc != null) { + Element eltEpilog = getFirstChildElementByName(eltMlrFnc, "Epilogue", false); + if (eltEpilog != null) { + Epilog epilog = parseEpilog(eltEpilog); + mlrFunc.setEpilog(epilog); + } + } + + return mlrFunc; + } + + private void parseDecisionTree(Element eltDecisionTree, TreenetFunction tnFunc) { + parseForest(getFirstChildElementByName(eltDecisionTree, "Forest", true), tnFunc); + + Element eltEarlyExits = getFirstChildElementByName(eltDecisionTree, "EarlyExits", false); + if (eltEarlyExits != null) + parseEarlyExits(eltEarlyExits, tnFunc); + } + + private void parseForest(Element eltForest, TreenetFunction tnFunc) { + //String strTotal = eltForest.getAttribute("total"); + // tnFunc.setNumTrees(Integer.parseInt(eltForest.getAttribute("total"))); + + ArrayList<Element> nl = XmlUtils.getChildrenByName(eltForest, "Tree"); + int n = nl.size(); + if (n == 0) + throw new DecisionTreeXmlException("<Forest> should have at least one <Tree> element"); + + for (int i = 0; i < n; i++) { + parseTree(nl.get(i), tnFunc); + } + } + + private void parseTree(Element eltTree, TreenetFunction tnFunc) { + String comment = getAttribute(eltTree, "comment", false); + String id = getAttribute(eltTree, "id", true); + if (treeIdSet.contains(id)) + throw new DecisionTreeXmlException("Duplicate tree id " + id); + else + treeIdSet.add(id); + + Tree tr = new Tree(id, comment); + tnFunc.setTree(tr); + + // DEBUG + //System.out.println("tree " + id); + ArrayList<Element> list = XmlUtils.getChildrenByName(eltTree, "Node"); + if (list == null || list.size() != 1) + throw new DecisionTreeXmlException("<Tree> should have exactly one root <Node> element"); + + Element eltNode = list.get(0); + InternalNode root = parseInternalNode(eltNode, tnFunc, tr); + tr.setRoot(root); + } + + private TreeNode parseTreeNode(Element eltNode, TreenetFunction tnFunc, Tree tr) { + String tag = eltNode.getNodeName(); + if (tag.equals("Node")) + return parseInternalNode(eltNode, tnFunc, tr); + else if (tag.equals("Response")) + return parseResponse(eltNode, tnFunc); + else + throw new DecisionTreeXmlException("ERROR: unknown tag <" + tag + ">. Should never reach here."); + } + + private InternalNode parseInternalNode(Element eltNode, TreenetFunction tnFunc, Tree tr) { + tr.incrInteralNodes(); + + String id = getAttribute(eltNode, "id", true); + if (nodeIdSet.contains(id)) + throw new DecisionTreeXmlException("Duplicate Internal Node id " + id); + else + nodeIdSet.add(id); + + String comment = getAttribute(eltNode, "comment", false); + + String feature = getAttribute(eltNode, "feature", true); + tnFunc.addFeature(feature); + + String value = getAttribute(eltNode, "value", true); + try { + Double.parseDouble(value); + } catch (NumberFormatException nfex) { + String errMsg = "Node " + id + ": value not a number: " + value; + throw new DecisionTreeXmlException(errMsg, nfex); + } + + ArrayList<Node> childNodes = new ArrayList<Node>(5); + + NodeList nl = eltNode.getChildNodes(); + int n = nl.getLength(); + Node nd; + for (int i = 0; i < n; i++) { + nd = nl.item(i); + if (nd.getNodeType() == Node.ELEMENT_NODE) { + String tag = nd.getNodeName(); + if (tag.equals("Node") || tag.equals("Response")) { + childNodes.add(nd); + } + } + } + + int numChildNodes = childNodes.size(); + if (numChildNodes != 2) { + String strNode = "Node: id=" + id + " " + feature + " " + value; + String errMsgNodes = "ERROR: A <Node> element should have exactly 2 child nodes. A child node can be <Node> or <Response>. " + strNode; + throw new DecisionTreeXmlException(errMsgNodes); + } + + TreeNode left = parseTreeNode((Element)childNodes.get(0), tnFunc, tr); + TreeNode right = parseTreeNode((Element)childNodes.get(1), tnFunc, tr); + + return new InternalNode(id, comment, feature, value, left, right); + } + + private ResponseNode parseResponse(Element eltResponse, TreenetFunction tnFunc) { + String id = getAttribute(eltResponse, "id", true); + if (respIdSet.contains(id)) + throw new DecisionTreeXmlException("Duplicate Response Node id " + id); + else + respIdSet.add(id); + + String comment = getAttribute(eltResponse, "comment", false); + + String strValue = eltResponse.getAttribute("value"); + double value; + try { + value = Double.parseDouble(strValue); + } catch (NumberFormatException ne) { + throw new DecisionTreeXmlException("Response Node " + id + " does not contain a double value. value=" + strValue); + } + + return new ResponseNode(id, comment, value); + } + + private void parseEarlyExits(Element eltEarlyExits, TreenetFunction tnFunc) { + ArrayList<Element> nl = XmlUtils.getChildrenByName(eltEarlyExits, "Exit"); + if (nl != null) { + int n = nl.size(); + for (int i = 0; i < n; i++) { + parseExit(nl.get(i), tnFunc); + } + } + } + + private void parseExit(Element eltExit, TreenetFunction tnFunc) { + String attr = getAttribute(eltExit, "tree", true); + int tree; + try { + tree = Integer.parseInt(attr); + } catch (NumberFormatException ex) { + String errMsg = "Invalid value for attribute tree: " + attr; + throw new DecisionTreeXmlException(errMsg); + } + + String strValue = getAttribute(eltExit, "value", true); + try { + Double.parseDouble(attr); + } catch (NumberFormatException ex) { + String errMsg = "Invalid value for attribute value: " + attr; + throw new DecisionTreeXmlException(errMsg); + } + + attr = getAttribute(eltExit, "op", true); + Operator op; + try { + op = Operator.parse(attr); + } catch (IllegalArgumentException ex) { + String errMsg = "Invalid value for attribute op: " + attr; + throw new DecisionTreeXmlException(errMsg); + } + + tnFunc.addEarlyExit(new EarlyExit(tree, op, strValue)); + + } + + private Epilog parseEpilog(Element eltEpilog) { + Element eltOp = XmlUtils.getFirstChildElement(eltEpilog); + if (eltOp.getNodeName().equals("Normalize")) { + try { + return parseNormalize(eltOp); + } catch (DecisionTreeXmlException e) { + return null; + } + } else if (eltOp.getNodeName().equals("Polytransform")) { + return parsePolytransform(eltOp); + } + else { + return null; + } + } + + private Epilog parseNormalize(Element eltNorm) { + Epilog epilog = new Epilog(); + FuncNormalize func = new FuncNormalize(); + epilog.setFunction(func); + + String strIsInv = getBoolAttribute(eltNorm, "isInverted", false); + if (strIsInv != null && strIsInv.equals("true")) { + String strInvFrom = getDoubleAttribute(eltNorm, "invertedFrom", true); + func.setInvertMethod(FuncNormalize.INV_INVERSION); + func.setInvertedFrom(strInvFrom); + } + + String strIsNeg = getAttribute(eltNorm, "isNegated", false); + if (strIsNeg != null && strIsNeg.equals("true")) { + if (func.getInvertMethod() == FuncNormalize.INV_NONE) + func.setInvertMethod(FuncNormalize.INV_NEGATION); + else + throw new DecisionTreeXmlException("cannot have both isInverted and isNegated defined in element <Normalize>"); + } + + func.setMean0(getDoubleAttribute(eltNorm, "mean0", false)); + func.setMean1(getDoubleAttribute(eltNorm, "mean1", false)); + func.setMean2(getDoubleAttribute(eltNorm, "mean2", false)); + func.setMean3(getDoubleAttribute(eltNorm, "mean3", false)); + func.setSd0(getDoubleAttribute(eltNorm, "sd0", false)); + func.setSd1(getDoubleAttribute(eltNorm, "sd1", false)); + func.setSd2(getDoubleAttribute(eltNorm, "sd2", false)); + func.setSd3(getDoubleAttribute(eltNorm, "sd3", false)); + func.setA0(getDoubleAttribute(eltNorm, "a0", false)); + func.setA1(getDoubleAttribute(eltNorm, "a1", false)); + func.setA2(getDoubleAttribute(eltNorm, "a2", false)); + func.setA3(getDoubleAttribute(eltNorm, "a3", false)); + func.setB0(getDoubleAttribute(eltNorm, "b0", false)); + func.setB1(getDoubleAttribute(eltNorm, "b1", false)); + func.setB2(getDoubleAttribute(eltNorm, "b2", false)); + func.setB3(getDoubleAttribute(eltNorm, "b3", false)); + + if (!func.validateParams()) + throw new DecisionTreeXmlException(errNormAttr); + + return epilog; + } + + private Epilog parsePolytransform(Element eltOp) { + Epilog epilog = new Epilog(); + FuncPolytransform func = new FuncPolytransform(); + + func.setA0(getDoubleAttribute(eltOp, "a0", false)); + func.setA1(getDoubleAttribute(eltOp, "a1", false)); + func.setA2(getDoubleAttribute(eltOp, "a2", false)); + func.setA3(getDoubleAttribute(eltOp, "a3", false)); + + if (!func.validateParams()) + throw new DecisionTreeXmlException(errPolyAttr); + + epilog.setFunction(func); + return epilog; + } + + /** + * Checks if the attribute name exists. + * + * @param eltNorm - the element containing the attribute + * @param attr - attribute name + * @return true if the attribute exists; or false, otherwise. + */ + private boolean checkAttrExist(Element eltNorm, String attr) { + Attr attrNode = eltNorm.getAttributeNode(attr); + if (attrNode != null) + return true; + else + return false; + } + + /** + * Returns the value of attribute. + * + * @param elt + * @param attr + * @param reqd + * @return If the attribute exists, the value of the attribute is returned, otherwise null is returned. + */ + private String getAttribute(Element elt, String attr, boolean reqd) { + Attr attrNode = elt.getAttributeNode(attr); + String val = null; + if (attrNode != null) + val = elt.getAttribute(attr); + + if (reqd && (val == null || val.equals(""))) + throw new DecisionTreeXmlException(elt.getTagName() + ": missing required attribute " + attr); + return val; + } + + private String getBoolAttribute(Element elt, String attr, boolean reqd) { + String strVal = getAttribute(elt, attr, reqd); + + if (strVal == null || + ((strVal.equals("true") || strVal.equals("false")))) { + return strVal; + } else { + String errMsg = "Attribute " + attr + " in Element " + elt.getTagName() + " is not a valid boolean value: " + strVal; + throw new DecisionTreeXmlException(errMsg); + } + } + + private String getIntAttribute(Element elt, String attr, boolean reqd) { + String strVal = getAttribute(elt, attr, reqd); + try { + if (strVal != null) + Integer.parseInt(strVal); + return strVal; + } catch (NumberFormatException ne) { + String errMsg = "Attribute " + attr + " in Element " + elt.getTagName() + " is not a valid integer: " + strVal; + throw new DecisionTreeXmlException(errMsg); + } + } + + private String getDoubleAttribute(Element elt, String attr, boolean reqd) { + String strVal = getAttribute(elt, attr, reqd); + try { + if (strVal != null) + Double.parseDouble(strVal); + return strVal; + } catch (NumberFormatException ne) { + String errMsg = "Attribute " + attr + " in Element " + elt.getTagName() + " is not a valid double: " + strVal; + throw new DecisionTreeXmlException(errMsg); + } + } + + private Element getFirstChildElementByName(Element parent, String childName, boolean reqd) { + Element elt = XmlUtils.getFirstChildElementByName(parent, childName); + if (elt == null && reqd) + throw new DecisionTreeXmlException(elt.getTagName() + ": missing required element " + childName); + return elt; + } + + private static void logErrors(String msg) { + logger.severe(msg); + System.out.println(msg); + } + + private static void logErrors(Exception ex, String msg) { + String errMsg = ex.getClass().getName() + " " + ex.getMessage() + ": " + msg; + logger.severe(errMsg); + System.out.println(errMsg); + } + + public static void main(String[] args) { + String fileName = "C:\\yst\\libMLR_framework\\mlr3135.xml"; + new MlrXmlParser().parseXmlFile(fileName); + } + +} diff --git a/libmlr/src/main/java/config/header_template.txt b/libmlr/src/main/java/config/header_template.txt new file mode 100644 index 00000000000..16b6feaed2c --- /dev/null +++ b/libmlr/src/main/java/config/header_template.txt @@ -0,0 +1,17 @@ +/** + * File: {0} + * Package: search/secore/libs/mlr + * Description: + * + * Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + * + * Confidential + * + * Conversion Time: {1,date,long} {1,time,long} + * + * MODEL_SIZE: {2} trees x {3} leaf nodes + */ +#include "mlrfeatures.h" +#include "mlrscorereq.h" +#include "mlrfns.h" + |