diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
16 files changed, 111 insertions, 317 deletions
diff --git a/config-model/src/main/java/com/yahoo/documentmodel/NewDocumentType.java b/config-model/src/main/java/com/yahoo/documentmodel/NewDocumentType.java index d98869e9dd3..4ff54d7ff1c 100644 --- a/config-model/src/main/java/com/yahoo/documentmodel/NewDocumentType.java +++ b/config-model/src/main/java/com/yahoo/documentmodel/NewDocumentType.java @@ -4,8 +4,10 @@ package com.yahoo.documentmodel; import com.yahoo.document.DataType; import com.yahoo.document.Document; import com.yahoo.document.Field; +import com.yahoo.document.ReferenceDataType; import com.yahoo.document.StructDataType; import com.yahoo.document.StructuredDataType; +import com.yahoo.document.TemporaryStructuredDataType; import com.yahoo.document.annotation.AnnotationType; import com.yahoo.document.annotation.AnnotationTypeRegistry; import com.yahoo.document.datatypes.FieldValue; @@ -383,4 +385,18 @@ public final class NewDocumentType extends StructuredDataType implements DataTyp } + private ReferenceDataType refToThis = null; + + @SuppressWarnings("deprecation") + public ReferenceDataType getReferenceDataType() { + if (refToThis == null) { + // super ugly, the APIs for this are horribly inconsistent + var tmptmp = TemporaryStructuredDataType.create(getName()); + var tmp = ReferenceDataType.createWithInferredId(tmptmp); + tmp.setTargetType((StructuredDataType) this); + refToThis = tmp; + } + return refToThis; + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Application.java b/config-model/src/main/java/com/yahoo/searchdefinition/Application.java index 16eef798acd..64688a7e70d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Application.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Application.java @@ -84,7 +84,6 @@ public class Application { List<Schema> schemasSomewhatOrdered = new ArrayList<>(schemas); for (Schema schema : new SearchOrderer().order(schemasSomewhatOrdered)) { - new FieldOperationApplierForStructs().processSchemaFields(schema); new FieldOperationApplierForSearch().process(schema); // TODO: Why is this not in the regular list? new Processing(properties).process(schema, logger, diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java index 4a449dc898f..2d9c81085fe 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java @@ -292,17 +292,8 @@ public class DocumentModelBuilder { else if (type instanceof ReferenceDataType) { ReferenceDataType t = (ReferenceDataType) type; var tt = t.getTargetType(); - if (tt instanceof TemporaryStructuredDataType) { - DataType targetType = resolveTemporariesRecurse(tt, repo, docs, replacements); - t.setTargetType((StructuredDataType) targetType); - } else if (tt instanceof DocumentType) { - DataType targetType = resolveTemporariesRecurse(tt, repo, docs, replacements); - // super ugly, the APIs for this are horribly inconsistent - var tmptmp = TemporaryStructuredDataType.create(tt.getName()); - var tmp = new ReferenceDataType(tmptmp, t.getId()); - tmp.setTargetType((StructuredDataType) targetType); - type = tmp; - } + var doc = getDocumentType(docs, tt.getId()); + type = doc.getReferenceDataType(); } if (type != original) { replacements.add(new TypeReplacement(original, type)); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java index 5e5623e2319..4a5a858f828 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java @@ -20,47 +20,8 @@ public class FieldOperationApplierForStructs extends FieldOperationApplier { for (SDDocumentType type : sdoc.getAllTypes()) { if (type.isStruct()) { apply(type); - copyFields(type, sdoc); } } } - @SuppressWarnings("deprecation") - private void copyFields(SDDocumentType structType, SDDocumentType sdoc) { - //find all fields in OTHER types that have this type: - List<SDDocumentType> list = new ArrayList<>(); - list.add(sdoc); - list.addAll(sdoc.getTypes()); - for (SDDocumentType anyType : list) { - Iterator<Field> fields = anyType.fieldIterator(); - while (fields.hasNext()) { - SDField field = (SDField) fields.next(); - maybePopulateField(sdoc, field, structType); - } - } - } - - private void maybePopulateField(SDDocumentType sdoc, SDField field, SDDocumentType structType) { - DataType structUsedByField = field.getFirstStructRecursive(); - if (structUsedByField == null) { - return; - } - if (structUsedByField.getName().equals(structType.getName())) { - //this field is using this type!! - field.populateWithStructFields(sdoc, field.getName(), field.getDataType(), 0); - field.populateWithStructMatching(sdoc, field.getDataType(), field.getMatching()); - } - } - - public void processSchemaFields(Schema schema) { - var sdoc = schema.getDocument(); - if (sdoc == null) return; - for (SDDocumentType type : sdoc.getAllTypes()) { - if (type.isStruct()) { - for (SDField field : schema.allExtraFields()) { - maybePopulateField(sdoc, field, type); - } - } - } - } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDDocumentType.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDDocumentType.java index a037d5046a7..b87bdd8907e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDDocumentType.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDDocumentType.java @@ -268,8 +268,8 @@ public class SDDocumentType implements Cloneable, Serializable { return field; } - public Field addField(String string, DataType dataType, boolean header, int code) { - SDField field = new SDField(this, string, code, dataType, header); + public Field addField(String fName, DataType dataType, boolean header, int code) { + SDField field = new SDField(this, fName, code, dataType); addField(field); return field; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java index bd6625d4bde..8263352e87f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java @@ -111,8 +111,8 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, /** Struct fields defined in this field */ private final Map<String,SDField> structFields = new java.util.LinkedHashMap<>(0); - /** The document that this field was declared in, or null*/ - private SDDocumentType ownerDocType = null; + /** The document that this field was declared in, or null */ + private SDDocumentType repoDocType = null; /** The aliases declared for this field. May pertain to indexes or attributes */ private final Map<String, String> aliasToName = new HashMap<>(); @@ -130,25 +130,24 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, * @param name the name of the field * @param dataType the datatype of the field */ - protected SDField(SDDocumentType repo, String name, int id, DataType dataType, boolean populate) { + public SDField(SDDocumentType repo, String name, int id, DataType dataType) { super(name, id, dataType); - populate(populate, repo, name, dataType); + this.repoDocType = repo; + populate(name, dataType); } - public SDField(SDDocumentType repo, String name, int id, DataType dataType) { - this(repo, name, id, dataType, true); + public SDField(String name, DataType dataType) { + this(null, name, dataType); } /** Creates a new field */ - public SDField(SDDocumentType repo, String name, DataType dataType, boolean populate) { - super(name, dataType); - populate(populate, repo, name, dataType); + public SDField(SDDocumentType repo, String name, DataType dataType) { + this(repo, name, dataType, null); } /** Creates a new field */ - protected SDField(SDDocumentType repo, String name, DataType dataType, SDDocumentType owner, boolean populate) { - super(name, dataType, owner == null ? null : owner.getDocumentType()); - populate(populate, repo, name, dataType); + protected SDField(SDDocumentType repo, String name, DataType dataType, SDDocumentType owner) { + this(repo, name, dataType, owner, null, 0); } /** @@ -159,27 +158,24 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, * @param owner the owning document (used to check for id collisions) * @param fieldMatching the matching object to set for the field */ - protected SDField(SDDocumentType repo, String name, DataType dataType, SDDocumentType owner, - Matching fieldMatching, boolean populate, int recursion) { + protected SDField(SDDocumentType repo, + String name, + DataType dataType, + SDDocumentType owner, + Matching fieldMatching, + int recursion) + { super(name, dataType, owner == null ? null : owner.getDocumentType()); + this.repoDocType = repo; + this.structFieldDepth = recursion; if (fieldMatching != null) this.setMatching(fieldMatching); - populate(populate, repo, name, dataType, fieldMatching, recursion); + populate(name, dataType); } - public SDField(SDDocumentType repo, String name, DataType dataType) { - this(repo, name, dataType, true); - } + private int structFieldDepth = 0; - public SDField(String name, DataType dataType) { - this(null, name, dataType); - } - - private void populate(boolean populate, SDDocumentType repo, String name, DataType dataType) { - populate(populate, repo, name, dataType, null, 0); - } - - private void populate(boolean populate, SDDocumentType repo, String name, DataType dataType, Matching fieldMatching, int recursion) { + private void populate(String name, DataType dataType) { if (dataType instanceof TensorDataType) { TensorType type = ((TensorDataType)dataType).getTensorType(); if (type.dimensions().stream().anyMatch(d -> d.isIndexed() && d.size().isEmpty())) @@ -194,10 +190,6 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, else { addQueryCommand("type " + dataType.getName()); } - if (populate || (dataType instanceof MapDataType)) { - populateWithStructFields(repo, name, dataType, recursion); - populateWithStructMatching(repo, dataType, fieldMatching); - } } public void setIsExtraField(boolean isExtra) { @@ -273,17 +265,23 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, } } + private boolean doneStructFields = false; + @SuppressWarnings("deprecation") - public void populateWithStructFields(SDDocumentType sdoc, String name, DataType dataType, int recursion) { - DataType dt = getFirstStructOrMapRecursive(); - if (dt == null) return; + private void actuallyMakeStructFields() { + if (doneStructFields) return; + if (getFirstStructOrMapRecursive() == null) { + doneStructFields = true; + return; + } + var sdoc = repoDocType; + var dataType = getDataType(); java.util.function.BiConsumer<String, DataType> supplyStructField = (fieldName, fieldType) -> { if (structFields.containsKey(fieldName)) return; - String subName = name.concat(".").concat(fieldName); - var subField = new SDField(sdoc, subName, fieldType, - ownerDocType, new Matching(), - true, recursion + 1); + String subName = getName().concat(".").concat(fieldName); + var subField = new SDField(sdoc, subName, fieldType, null, + null, structFieldDepth + 1); structFields.put(fieldName, subField); }; @@ -292,15 +290,16 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, supplyStructField.accept("key", mdt.getKeyType()); supplyStructField.accept("value", mdt.getValueType()); } else { - if (recursion >= 10) return; + if (structFieldDepth >= 10) { + // too risky, infinite recursion + doneStructFields = true; + return; + } if (dataType instanceof CollectionDataType) { dataType = ((CollectionDataType)dataType).getNestedType(); } - if (dataType instanceof TemporaryStructuredDataType) { - SDDocumentType subType = sdoc != null ? sdoc.getType(dataType.getName()) : null; - if (subType == null) { - throw new IllegalArgumentException("Could not find struct '" + dataType.getName() + "'."); - } + SDDocumentType subType = sdoc != null ? sdoc.getType(dataType.getName()) : null; + if (dataType instanceof TemporaryStructuredDataType && subType != null) { for (Field field : subType.fieldSet()) { supplyStructField.accept(field.getName(), field.getDataType()); } @@ -310,37 +309,23 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, supplyStructField.accept(field.getName(), field.getDataType()); } } - } - } - - public void populateWithStructMatching(SDDocumentType sdoc, DataType dataType, Matching superFieldMatching) { - if (sdoc == null) return; - if (superFieldMatching == null) return; - DataType dt = getFirstStructOrMapRecursive(); - if (dt == null) return; - - if (dataType instanceof MapDataType) { - // old code here would never do anything useful, should we do something here? - return; - } else { - if (dataType instanceof CollectionDataType) { - dataType = ((CollectionDataType)dataType).getNestedType(); + if ((subType == null) && (structFields.size() > 0)) { + throw new IllegalArgumentException("Cannot find matching (repo=" + sdoc + ") for subfields in " + + this + " [" + getDataType() + getDataType().getClass() + + "] with " + structFields.size() + " struct fields"); } - if (dataType instanceof StructDataType) { - SDDocumentType subType = sdoc.getType(dataType.getName()); - if (subType == null) { - throw new IllegalArgumentException("Could not find struct " + dataType.getName()); - } + // populate struct fields with matching + if (subType != null) { for (Field f : subType.fieldSet()) { if (f instanceof SDField) { SDField field = (SDField) f; Matching subFieldMatching = new Matching(); - subFieldMatching.merge(superFieldMatching); + subFieldMatching.merge(this.matching); subFieldMatching.merge(field.getMatching()); SDField subField = structFields.get(field.getName()); if (subField != null) { - subFieldMatching.merge(subField.getMatching()); - subField.populateWithStructMatching(sdoc, field.getDataType(), subFieldMatching); + // we just made this with no matching, so nop: + // subFieldMatching.merge(subField.getMatching()); subField.setMatching(subFieldMatching); } } else { @@ -349,8 +334,11 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, } } } + doneStructFields = true; } + private Matching matchingForStructFields = null; + public void addOperation(FieldOperation op) { pendingOperations.add(op); } @@ -723,7 +711,10 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, /** Returns list of static struct fields */ @Override - public Collection<SDField> getStructFields() { return structFields.values(); } + public Collection<SDField> getStructFields() { + actuallyMakeStructFields(); + return structFields.values(); + } /** * Returns a struct field defined in this field, @@ -732,6 +723,7 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, */ @Override public SDField getStructField(String name) { + actuallyMakeStructFields(); if (name.contains(".")) { String superFieldName = name.substring(0,name.indexOf(".")); String subFieldName = name.substring(name.indexOf(".")+1); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/TemporarySDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/TemporarySDField.java index 4ced104fa55..8c17b607f94 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/TemporarySDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/TemporarySDField.java @@ -8,12 +8,12 @@ import com.yahoo.document.DataType; */ public class TemporarySDField extends SDField { - public TemporarySDField(String name, DataType dataType, SDDocumentType owner) { - super(owner, name, dataType, owner, false); + public TemporarySDField(SDDocumentType repo, String name, DataType dataType, SDDocumentType owner) { + super(repo, name, dataType, owner); } - public TemporarySDField(String name, DataType dataType) { - super(null, name, dataType, false); + public TemporarySDField(SDDocumentType repo, String name, DataType dataType) { + super(repo, name, dataType); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java index fe3ac11af27..a5f5f961ab5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java @@ -24,6 +24,8 @@ public class IndexingOperation implements FieldOperation { this.script = script; } + public ScriptExpression getScript() { return script; } + public void apply(SDField field) { field.setIndexingScript(script); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java index 36c11b33b23..caeebd65f4f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java @@ -291,13 +291,12 @@ public class ConvertParsedFields { schema.addIndex(index); } - SDDocumentType convertStructDeclaration(Schema schema, ParsedStruct parsed) { + SDDocumentType convertStructDeclaration(Schema schema, SDDocumentType document, ParsedStruct parsed) { // TODO - can we cleanup this mess var structProxy = new SDDocumentType(parsed.name(), schema); - structProxy.setStruct(context.resolveStruct(parsed)); for (var parsedField : parsed.getFields()) { var fieldType = context.resolveType(parsedField.getType()); - var field = new SDField(structProxy, parsedField.name(), fieldType); + var field = new SDField(document, parsedField.name(), fieldType); convertCommonFieldSettings(field, parsedField); structProxy.addField(field); if (parsedField.hasIdOverride()) { @@ -307,6 +306,7 @@ public class ConvertParsedFields { for (String inherit : parsed.getInherited()) { structProxy.inherit(new DataTypeName(inherit)); } + structProxy.setStruct(context.resolveStruct(parsed)); return structProxy; } @@ -314,7 +314,7 @@ public class ConvertParsedFields { var annType = context.resolveAnnotation(parsed.name()); var payload = parsed.getStruct(); if (payload.isPresent()) { - var structProxy = convertStructDeclaration(schema, payload.get()); + var structProxy = convertStructDeclaration(schema, document, payload.get()); document.addType(structProxy); } document.addAnnotation(annType); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java index 67e6c88d043..2d9a788cfef 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java @@ -159,7 +159,7 @@ public class ConvertSchemaCollection { document.inherit(parent); } for (var struct : parsed.getStructs()) { - var structProxy = fieldConverter.convertStructDeclaration(schema, struct); + var structProxy = fieldConverter.convertStructDeclaration(schema, document, struct); document.addType(structProxy); } for (var annotation : parsed.getAnnotations()) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/AddExtraFieldsToDocument.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/AddExtraFieldsToDocument.java index 51defffa00b..0be48d1fd25 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/AddExtraFieldsToDocument.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/AddExtraFieldsToDocument.java @@ -70,7 +70,7 @@ public class AddExtraFieldsToDocument extends Processor { if (docField == null) { ImmutableSDField existingField = schema.getField(field.getName()); if (existingField == null) { - SDField newField = new SDField(document, field.getName(), field.getDataType(), true); + SDField newField = new SDField(document, field.getName(), field.getDataType()); newField.setIsExtraField(true); document.addField(newField); } else if (!existingField.isImportedField()) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/CreatePositionZCurve.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/CreatePositionZCurve.java index 0bb1b7da769..d7882c7f8fb 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/CreatePositionZCurve.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/CreatePositionZCurve.java @@ -10,6 +10,7 @@ import com.yahoo.document.PositionDataType; import com.yahoo.searchdefinition.Schema; import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchdefinition.document.GeoPos; +import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.vespa.documentmodel.SummaryField; import com.yahoo.vespa.documentmodel.SummaryTransform; @@ -35,8 +36,11 @@ import java.util.logging.Level; */ public class CreatePositionZCurve extends Processor { + private final SDDocumentType repo; + public CreatePositionZCurve(Schema schema, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { super(schema, deployLogger, rankProfileRegistry, queryProfiles); + this.repo = schema.getDocument(); } private boolean useV8GeoPositions = false; @@ -105,7 +109,7 @@ public class CreatePositionZCurve extends Processor { "' already created."); } boolean isArray = inputField.getDataType() instanceof ArrayDataType; - SDField field = new SDField(fieldName, isArray ? DataType.getArray(DataType.LONG) : DataType.LONG); + SDField field = new SDField(repo, fieldName, isArray ? DataType.getArray(DataType.LONG) : DataType.LONG); Attribute attribute = new Attribute(fieldName, Attribute.Type.LONG, isArray ? Attribute.CollectionType.ARRAY : Attribute.CollectionType.SINGLE); attribute.setPosition(true); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/UriHack.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/UriHack.java index 84dc6d369fc..7397f9a289c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/UriHack.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/UriHack.java @@ -61,7 +61,7 @@ public class UriHack extends Processor { String partName = uriName + "." + suffix; // I wonder if this is explicit in qrs or implicit in backend? // search.addFieldSetItem(uriName, partName); - SDField partField = new SDField(partName, generatedType); + SDField partField = new SDField(schema.getDocument(), partName, generatedType); partField.setIndexStructureField(uriField.doesIndexing()); partField.setRankType(uriField.getRankType()); partField.setStemming(Stemming.NONE); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 8390cc59b6f..ac99bee93ed 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -108,6 +108,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri public static final Logger log = Logger.getLogger(VespaModel.class.getName()); private final Version version; + private final Version wantedNodeVersion; private final ConfigModelRepo configModelRepo = new ConfigModelRepo(); private final AllocatedHosts allocatedHosts; @@ -170,6 +171,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri throws IOException, SAXException { super("vespamodel"); version = deployState.getVespaVersion(); + wantedNodeVersion = deployState.getWantedNodeVespaVersion(); fileReferencesRepository = new FileReferencesRepository(deployState.getFileRegistry()); rankingConstants = new RankingConstants(deployState.getFileRegistry(), Optional.empty()); validationOverrides = deployState.validationOverrides(); @@ -407,6 +409,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri return version; } + @Override + public Version wantedNodeVersion() { + return wantedNodeVersion; + } + /** * Resolves config of the given type and config id, by first instantiating the correct {@link com.yahoo.config.ConfigInstance.Builder}, * calling {@link #getConfig(com.yahoo.config.ConfigInstance.Builder, String)}. The default values used will be those of the config diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..88139de7888 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -36,15 +36,13 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { - private final ApplicationPackage app; private final String modelPath; private final String defaultOutput; private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); - private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { - this.app = app; + private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); @@ -81,15 +79,7 @@ public class OnnxModelInfo { Set<Long> unboundSizes = new HashSet<>(); Map<String, Long> symbolicSizes = new HashMap<>(); resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes); - - TensorType type = TensorType.empty; - if (inputTypes.size() > 0 && onnxTypeInfo.needModelProbe(symbolicSizes)) { - type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes); - } - if (type.equals(TensorType.empty)) { - type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); - } - return type; + return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); } return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType()); } @@ -160,8 +150,7 @@ public class OnnxModelInfo { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); String json = onnxModelToJson(model, path); storeGeneratedInfo(json, path, app); - return jsonToModelInfo(json, app); - + return jsonToModelInfo(json); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -170,7 +159,7 @@ public class OnnxModelInfo { static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) { try { String json = readGeneratedInfo(path, app); - return jsonToModelInfo(json, app); + return jsonToModelInfo(json); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -213,7 +202,7 @@ public class OnnxModelInfo { return out.toString(); } - static public OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException { + static public OnnxModelInfo jsonToModelInfo(String json) throws IOException { ObjectMapper m = new ObjectMapper(); JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); @@ -233,7 +222,7 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + return new OnnxModelInfo(path, inputs, outputs, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { @@ -364,21 +353,6 @@ public class OnnxModelInfo { return builder.build(); } - boolean needModelProbe(Map<String, Long> symbolicSizes) { - for (OnnxDimensionInfo onnxDimension : dimensions) { - if (onnxDimension.hasSymbolicName()) { - if (symbolicSizes == null) - return true; - if ( ! symbolicSizes.containsKey(onnxDimension.getSymbolicName())) { - return true; - } - } else if (onnxDimension.getSize() == 0) { - return true; - } - } - return false; - } - @Override public String toString() { return "(" + valueType.id() + ")" + diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java deleted file mode 100644 index 2e2ebdeb98f..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ /dev/null @@ -1,152 +0,0 @@ -package com.yahoo.vespa.model.ml; - -import com.fasterxml.jackson.core.JsonEncoding; -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.IOUtils; -import com.yahoo.path.Path; -import com.yahoo.tensor.TensorType; - -import java.io.BufferedReader; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Map; - -/** - * Defers to 'vespa-analyze-onnx-model' to determine the output type given - * a set of inputs. For situations with symbolic dimension sizes that can't - * easily be determined. - * - * @author lesters - */ -public class OnnxModelProbe { - - private static final String binary = "vespa-analyze-onnx-model"; - - static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { - TensorType outputType = TensorType.empty; - String contextKey = createContextKey(outputName, inputTypes); - - try { - // Check if output type has already been probed - outputType = readProbedOutputType(app, modelPath, contextKey); - - // Otherwise, run vespa-analyze-onnx-model if the model is available - if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { - String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); - String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); - outputType = outputTypeFromJson(jsonOutput, outputName); - if ( ! outputType.equals(TensorType.empty)) { - writeProbedOutputType(app, modelPath, contextKey, outputType); - } - } - - } catch (IllegalArgumentException | IOException | InterruptedException e) { - e.printStackTrace(System.err); - } - - return outputType; - } - - private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { - StringBuilder key = new StringBuilder().append(onnxName).append(":"); - inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) - .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(",")); - return key.substring(0, key.length()-1); - } - - private static Path probedOutputTypesPath(Path path) { - String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types"; - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); - } - - static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output, - Map<String, TensorType> inputTypes, TensorType type) throws IOException { - writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type); - } - - private static void writeProbedOutputType(ApplicationPackage app, Path modelPath, - String contextKey, TensorType type) throws IOException { - String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath(); - IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true); - } - - private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath, - String contextKey) throws IOException { - ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath)); - if ( ! file.exists()) { - return TensorType.empty; - } - try (BufferedReader reader = new BufferedReader(file.createReader())) { - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String key = parts[0]; - if (key.equals(contextKey)) { - return TensorType.fromSpec(parts[1]); - } - } - } - return TensorType.empty; - } - - private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { - ObjectMapper m = new ObjectMapper(); - JsonNode root = m.readTree(json); - if ( ! root.isObject() || ! root.has("outputs")) { - return TensorType.empty; - } - JsonNode outputs = root.get("outputs"); - if ( ! outputs.has(outputName)) { - return TensorType.empty; - } - return TensorType.fromSpec(outputs.get(outputName).asText()); - } - - private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); - g.writeStartObject(); - g.writeStringField("model", modelPath); - g.writeObjectFieldStart("inputs"); - for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) { - g.writeStringField(input.getKey(), input.getValue().toString()); - } - g.writeEndObject(); - g.writeEndObject(); - g.close(); - return out.toString(); - } - - private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { - ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); - StringBuilder output = new StringBuilder(); - Process process = processBuilder.start(); - - // Write json array to process stdin - OutputStream os = process.getOutputStream(); - os.write(jsonInput.getBytes(StandardCharsets.UTF_8)); - os.close(); - - // Read output from stdout - InputStream inputStream = process.getInputStream(); - while (true) { - int b = inputStream.read(); - if (b == -1) break; - output.append((char)b); - } - int returnCode = process.waitFor(); - if (returnCode != 0) { - throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". Output:\n" + output); - } - return output.toString(); - } - -} |