diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
commit | 1d88554bd513783715425120e76fc5f2a86f439f (patch) | |
tree | 166c86107d3620014cc7e26d85118c311e1b8cf0 | |
parent | a01bc21d9bcbc417a9fb2591079561f59f76865e (diff) |
Java type only interface between imported-models and config models
This avoids class incompatibility problems when passing an
imported model across bundle boundaries to a config model.
Tensor string parsing has been sped up as this relies on it more.
10 files changed, 176 insertions, 68 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 59aa5b3ba53..259ac5227ae 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -202,8 +202,9 @@ public class ConvertedModel { // Add expressions Map<String, ExpressionFunction> expressions = new HashMap<>(); - for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { - addExpression(output.getSecond(), output.getFirst(), + for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) { + ExpressionFunction expression = asExpressionFunction(outputFunction); + addExpression(expression, expression.getName(), constantsReplacedByFunctions, model, store, profile, queryProfiles, expressions); @@ -218,6 +219,23 @@ public class ConvertedModel { return expressions; } + private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) { + try { + Map<String, TensorType> argumentTypes = new HashMap<>(); + for (Map.Entry<String, String> entry : function.argumentTypes().entrySet()) + argumentTypes.put(entry.getKey(), TensorType.fromSpec(entry.getValue())); + + return new ExpressionFunction(function.name(), + function.arguments(), + new RankingExpression(function.expression()), + argumentTypes, + function.returnType().map(TensorType::fromSpec)); + } + catch (ParseException e) { + throw new IllegalArgumentException("Gor an illegal argument from importing " + function.name(), e); + } + } + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, @@ -248,7 +266,9 @@ public class ConvertedModel { return store.readExpressions(); } - private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); store.writeSmallConstant(constantName, constantValue); profile.addConstant(constantName, asValue(constantValue)); } @@ -258,7 +278,8 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByFunctions, String constantName, - Tensor constantValue) { + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName); if (rankingExpressionFunctionOverridingConstant != null) { TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles)); @@ -306,14 +327,14 @@ public class ConvertedModel { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.inputs().get(functionName); - if (requiredType == null) continue; // Not a required function + Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); + if ( ! requiredType.isPresent()) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); if (rankingExpressionFunction == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + " but this function is not present in " + - profile); + "' of type " + requiredType.get() + + " but this function is not present in " + profile); // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second // phase and summary features), as it may only resolve correctly given those bindings // Or, probably better, annotate the functions with type constraints here and verify during general @@ -321,12 +342,12 @@ public class ConvertedModel { TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + + "' of type " + requiredType.get() + " which must be produced by a function in the rank profile, but " + "this function references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) + if ( ! actualType.isAssignableTo(requiredType.get())) throw new IllegalArgumentException("Model refers input '" + functionName + "'. " + - typeMismatchExplanation(requiredType, actualType)); + typeMismatchExplanation(requiredType.get(), actualType)); } } @@ -339,7 +360,7 @@ public class ConvertedModel { /** Add the generated functions to the rank profile */ private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) { - model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy())); + model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v))); } /** @@ -383,7 +404,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -391,7 +412,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -487,7 +508,7 @@ public class ConvertedModel { if (referenceNode.getOutput() == null) { // function references cannot specify outputs names.add(referenceNode.getName()); if (model.functions().containsKey(referenceNode.getName())) { - addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model); + addFunctionNamesIn(RankingExpression.from(model.functions().get(referenceNode.getName())).getRoot(), names, model); } } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 2866a2c76b2..c2235b9abe9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; @@ -59,18 +60,29 @@ public class ImportedModel { /** Returns an immutable map of the inputs of this */ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } + // CFG + public Optional<String> inputTypeSpec(String input) { + return Optional.ofNullable(inputs.get(input)).map(TensorType::toString); + } + /** - * Returns an immutable map of the small constants of this. + * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form. * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + // CFG + public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } + + boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } /** * Returns an immutable map of the large constants of this. * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. * For TensorFlow this corresponds to Variable files stored separately. */ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + // CFG + public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } + + boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } /** * Returns an immutable map of the expressions of this - corresponding to graph nodes @@ -79,11 +91,14 @@ public class ImportedModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + // TODO: Most of the usage of the above can be replaced by a faster expressionNames method + /** * Returns an immutable map of the functions that are part of this model. * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } + // CFG + public Map<String, String> functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -108,43 +123,60 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionFunction>> outputExpressions() { - List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); + // CFG + public List<ImportedFunction> outputExpressions() { + List<ImportedFunction> functions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), - signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(), + signatureEntry.getKey() + "." + outputEntry.getKey())); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty()))); + functions.add(new ImportedFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty())); } } } - return expressions; + return functions; + } + + private Map<String, String> asTensorStrings(Map<String, Tensor> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, Tensor> entry : map.entrySet()) { + Tensor tensor = entry.getValue(); + // TODO: See Tensor.toStandardString + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) + values.put(entry.getKey(), tensor.toString()); + else + values.put(entry.getKey(), tensor.type() + ":" + tensor); + } + return values; + } + + private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, RankingExpression> entry : map.entrySet()) + values.put(entry.getKey(), entry.getValue().getRoot().toString()); + return values; } /** @@ -213,6 +245,17 @@ public class ImportedModel { Optional.empty()); } + /** Returns the expression this output references as an imported function */ + public ImportedFunction outputFunction(String outputName, String functionName) { + return new ImportedFunction(functionName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); + } + + // CFG + @Override public String toString() { return "signature '" + name + "'"; } @@ -223,4 +266,37 @@ public class ImportedModel { } + // CFG + public static class ImportedFunction { + + private final String name; + private final List<String> arguments; + private final Map<String, String> argumentTypes; + private final String expression; + private final Optional<String> returnType; + + public ImportedFunction(String name, List<String> arguments, RankingExpression expression, + Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { + this.name = name; + this.arguments = arguments; + this.expression = expression.getRoot().toString(); + this.argumentTypes = asStrings(argumentTypes); + this.returnType = returnType.map(TensorType::toString); + } + + private static Map<String, String> asStrings(Map<String, TensorType> map) { + Map<String, String> stringMap = new HashMap<>(); + for (Map.Entry<String, TensorType> entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + + public String name() { return name; } + public List<String> arguments() { return Collections.unmodifiableList(arguments); } + public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); } + public String expression() { return expression; } + public Optional<String> returnType() { return returnType; } + + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java index 1b7532631e1..bfdaaca1dd7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java @@ -69,6 +69,7 @@ public class ImportedModels { * models directory works * @return the model at this path or null if none */ + // CFG public ImportedModel get(File modelPath) { return importedModels.get(toName(modelPath)); } @@ -78,6 +79,7 @@ public class ImportedModels { } /** Returns an immutable collection of all the imported models */ + // CFG public Collection<ImportedModel> all() { return importedModels.values(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index cb095e81147..8a885938bf9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -121,7 +121,7 @@ public abstract class ModelImporter { private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) { return operation.function(); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index d3996da9b58..315456c2613 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -28,13 +28,13 @@ public class OnnxMnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("test_Variable"); + Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("test_Variable_1"); + Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -84,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase { private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index 6215997d8f9..be676186017 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -24,13 +24,13 @@ public class TensorFlowMnistSoftmaxImportTestCase { // Check constants Assert.assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); + Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); + Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index c3b82cccb46..4ff0c96d369 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -93,8 +93,8 @@ public class TestableTensorFlowModel { static Context contextFrom(ImportedModel result) { TestableModelContext context = new TestableModelContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } @@ -108,7 +108,7 @@ public class TestableTensorFlowModel { private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { - RankingExpression e = model.functions().get(functionName); + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); evaluateFunctionDependencies(context, model, e.getRoot()); context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 483ccd330e0..1ee22c69c23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -230,7 +230,7 @@ public interface Tensor { * @return the tensor on the standard string format */ static String toStandardString(Tensor tensor) { - if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Never do that? + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Always do that return tensor.type() + ":" + contentToString(tensor); else return contentToString(tensor); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 000f33696f2..fa32d385004 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -76,32 +76,41 @@ class TensorParser { } private static Tensor fromCellString(Tensor.Builder builder, String s) { - s = s.trim().substring(1).trim(); - while (s.length() > 1) { - int keyOrTensorEnd = s.indexOf('}'); + int index = 1; + index = skipSpace(index, s); + while (index + 1 < s.length()) { + int keyOrTensorEnd = s.indexOf('}', index); TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); - if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAdress is empty - addLabels(s.substring(0, keyOrTensorEnd + 1), addressBuilder); - s = s.substring(keyOrTensorEnd + 1).trim(); - if ( ! s.startsWith(":")) - throw new IllegalArgumentException("Expecting a ':' after " + s + ", got '" + s + "'"); - s = s.substring(1); + if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty + addLabels(s.substring(index, keyOrTensorEnd + 1), addressBuilder); + index = keyOrTensorEnd + 1; + index = skipSpace(index, s); + if ( s.charAt(index) != ':') + throw new IllegalArgumentException("Expecting a ':' after " + s.substring(index) + ", got '" + s + "'"); + index++; } - int valueEnd = s.indexOf(','); + int valueEnd = s.indexOf(',', index); if (valueEnd < 0) { // last value - valueEnd = s.indexOf("}"); + valueEnd = s.indexOf('}', index); if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } TensorAddress address = addressBuilder.build(); - Double value = asDouble(address, s.substring(0, valueEnd).trim()); + Double value = asDouble(address, s.substring(index, valueEnd).trim()); builder.cell(address, value); - s = s.substring(valueEnd+1).trim(); + index = valueEnd+1; + index = skipSpace(index, s); } return builder.build(); } + private static int skipSpace(int index, String s) { + while (index < s.length() && s.charAt(index) == ' ') + index++; + return index; + } + /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */ private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { mapAddressString = mapAddressString.trim(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 38a8329bff1..122b6019884 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -101,7 +101,7 @@ public class TensorTestCase { " {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"), Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:0,y:0,z:0}:1, {x:0,y:1,z:0}:0, {x:1,y:0,z:0}:0, {x:1,y:1,z:0}:0, {x:2,y:0,z:0}:0, {x:2,y:1,z:0}:0, "+ - " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 }"), + " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 } "), Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x")); } |