diff options
Diffstat (limited to 'config-model/src/main/java')
23 files changed, 672 insertions, 460 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index 51f3455762c..480b6590555 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -79,6 +79,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private boolean useV8GeoPositions = false; private List<String> environmentVariables = List.of(); private boolean avoidRenamingSummaryFeatures = false; + private boolean experimentalSdParsing = false; @Override public ModelContext.FeatureFlags featureFlags() { return this; } @Override public boolean multitenant() { return multitenant; } @@ -138,6 +139,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public boolean useV8GeoPositions() { return useV8GeoPositions; } @Override public List<String> environmentVariables() { return environmentVariables; } @Override public boolean avoidRenamingSummaryFeatures() { return this.avoidRenamingSummaryFeatures; } + @Override public boolean experimentalSdParsing() { return this.experimentalSdParsing; } public TestProperties maxUnCommittedMemory(int maxUnCommittedMemory) { this.maxUnCommittedMemory = maxUnCommittedMemory; @@ -375,6 +377,11 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea return this; } + public TestProperties setExperimentalSdParsing(boolean value) { + this.experimentalSdParsing = value; + return this; + } + public static class Spec implements ConfigServerSpec { private final String hostName; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Application.java b/config-model/src/main/java/com/yahoo/searchdefinition/Application.java index 64688a7e70d..16eef798acd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Application.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Application.java @@ -84,6 +84,7 @@ public class Application { List<Schema> schemasSomewhatOrdered = new ArrayList<>(schemas); for (Schema schema : new SearchOrderer().order(schemasSomewhatOrdered)) { + new FieldOperationApplierForStructs().processSchemaFields(schema); new FieldOperationApplierForSearch().process(schema); // TODO: Why is this not in the regular list? new Processing(properties).process(schema, logger, diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ApplicationBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/ApplicationBuilder.java index 533546b4d39..67d8d679275 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ApplicationBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/ApplicationBuilder.java @@ -15,6 +15,9 @@ import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.config.QueryProfileXMLReader; +import com.yahoo.searchdefinition.parser.ConvertSchemaCollection; +import com.yahoo.searchdefinition.parser.IntermediateCollection; +import com.yahoo.searchdefinition.parser.IntermediateParser; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchdefinition.parser.SDParser; import com.yahoo.searchdefinition.parser.SimpleCharStream; @@ -43,6 +46,7 @@ import java.util.Set; */ public class ApplicationBuilder { + private final IntermediateCollection mediator; private final ApplicationPackage applicationPackage; private final List<Schema> schemas = new ArrayList<>(); private final DocumentTypeManager documentTypeManager = new DocumentTypeManager(); @@ -98,10 +102,17 @@ public class ApplicationBuilder { this(rankProfileRegistry, queryProfileRegistry, new TestProperties()); } + /** For testing only */ + public ApplicationBuilder(ModelContext.Properties properties) { + this(new RankProfileRegistry(), new QueryProfileRegistry(), properties); + } + + /** For testing only */ public ApplicationBuilder(RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfileRegistry, ModelContext.Properties properties) { this(MockApplicationPackage.createEmpty(), new MockFileRegistry(), new BaseDeployLogger(), properties, rankProfileRegistry, queryProfileRegistry); } + /** normal constructor */ public ApplicationBuilder(ApplicationPackage app, FileRegistry fileRegistry, DeployLogger deployLogger, @@ -118,6 +129,7 @@ public class ApplicationBuilder { RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfileRegistry, boolean documentsOnly) { + this.mediator = new IntermediateCollection(deployLogger, properties); this.applicationPackage = applicationPackage; this.rankProfileRegistry = rankProfileRegistry; this.queryProfileRegistry = queryProfileRegistry; @@ -133,13 +145,17 @@ public class ApplicationBuilder { * Adds a schema to this application. * * @param fileName the name of the file to import - * @return the name of the imported object * @throws IOException thrown if the file can not be read for some reason * @throws ParseException thrown if the file does not contain a valid search definition */ - public Schema addSchemaFile(String fileName) throws IOException, ParseException { + public void addSchemaFile(String fileName) throws IOException, ParseException { + if (properties.featureFlags().experimentalSdParsing()) { + var parsedName = mediator.addSchemaFromFile(fileName); + addRankProfileFiles(parsedName); + return; + } File file = new File(fileName); - return addSchema(IOUtils.readFile(file)); + addSchema(IOUtils.readFile(file)); } /** @@ -149,8 +165,19 @@ public class ApplicationBuilder { * @param reader the reader whose content to import */ public void addSchema(NamedReader reader) { + if (properties.featureFlags().experimentalSdParsing()) { + try { + var parsedName = mediator.addSchemaFromReader(reader); + addRankProfileFiles(parsedName); + } catch (ParseException e) { + throw new IllegalArgumentException("Could not parse schema file '" + reader.getName() + "'", e); + } + return; + } try { - String schemaName = addSchema(IOUtils.readAll(reader)).getName(); + Schema schema = createSchema(IOUtils.readAll(reader)); + add(schema); + String schemaName = schema.getName(); String schemaFileName = stripSuffix(reader.getName(), ApplicationPackage.SD_NAME_SUFFIX); if ( ! schemaFileName.equals(schemaName)) { throw new IllegalArgumentException("The file containing schema '" + schemaName + "' must be named '" + @@ -176,8 +203,13 @@ public class ApplicationBuilder { * * @param schemaString the content of the schema */ - public Schema addSchema(String schemaString) throws ParseException { - return add(createSchema(schemaString)); + public void addSchema(String schemaString) throws ParseException { + if (properties.featureFlags().experimentalSdParsing()) { + var parsed = mediator.addSchemaFromString(schemaString); + addRankProfileFiles(parsed.name()); + return; + } + add(createSchema(schemaString)); } /** @@ -202,6 +234,9 @@ public class ApplicationBuilder { } private Schema parseSchema(String schemaString) throws ParseException { + if (properties.featureFlags().experimentalSdParsing()) { + throw new IllegalArgumentException("should use new parser only"); + } SimpleCharStream stream = new SimpleCharStream(schemaString); try { return parserOf(stream).schema(documentTypeManager); @@ -215,6 +250,10 @@ public class ApplicationBuilder { private void addRankProfileFiles(Schema schema) { if (applicationPackage == null) return; + if (properties.featureFlags().experimentalSdParsing()) { + throw new IllegalArgumentException("should use new parser only"); + } + Path legacyRankProfilePath = ApplicationPackage.SEARCH_DEFINITIONS_DIR.append(schema.getName()); for (NamedReader reader : applicationPackage.getFiles(legacyRankProfilePath, ".profile")) parseRankProfile(reader, schema); @@ -224,8 +263,28 @@ public class ApplicationBuilder { parseRankProfile(reader, schema); } + private void addRankProfileFiles(String schemaName) throws ParseException { + if (applicationPackage == null) return; + if (! properties.featureFlags().experimentalSdParsing()) { + throw new IllegalArgumentException("should use old parser only"); + } + + Path legacyRankProfilePath = ApplicationPackage.SEARCH_DEFINITIONS_DIR.append(schemaName); + for (NamedReader reader : applicationPackage.getFiles(legacyRankProfilePath, ".profile")) { + mediator.addRankProfileFile(schemaName, reader); + } + + Path rankProfilePath = ApplicationPackage.SCHEMAS_DIR.append(schemaName); + for (NamedReader reader : applicationPackage.getFiles(rankProfilePath, ".profile")) { + mediator.addRankProfileFile(schemaName, reader); + } + } + /** Parses the rank profile of the given reader and adds it to the rank profile registry for this schema. */ private void parseRankProfile(NamedReader reader, Schema schema) { + if (properties.featureFlags().experimentalSdParsing()) { + throw new IllegalArgumentException("should use new parser only"); + } try { SimpleCharStream stream = new SimpleCharStream(IOUtils.readAll(reader.getReader())); try { @@ -256,7 +315,19 @@ public class ApplicationBuilder { */ public Application build(boolean validate) { if (application != null) throw new IllegalStateException("Application already built"); - + if (properties.featureFlags().experimentalSdParsing()) { + var converter = new ConvertSchemaCollection(mediator, + documentTypeManager, + applicationPackage, + fileRegistry, + deployLogger, + properties, + rankProfileRegistry, + documentsOnly); + for (var schema : converter.convertToSchemas()) { + add(schema); + } + } application = new Application(applicationPackage, schemas, rankProfileRegistry, @@ -355,7 +426,7 @@ public class ApplicationBuilder { } /** - * Convenience factory methdd to create a SearchBuilder from multiple SD files. Only for testing. + * Convenience factory methods to create a SearchBuilder from multiple SD files. Only for testing. */ public static ApplicationBuilder createFromFiles(Collection<String> fileNames) throws IOException, ParseException { return createFromFiles(fileNames, new BaseDeployLogger()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java index a29d66dc8f2..0db810d5933 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/DocumentModelBuilder.java @@ -34,6 +34,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.IdentityHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; @@ -155,9 +156,9 @@ public class DocumentModelBuilder { private static void addSearchField(SDField field, SearchDef searchDef) { SearchField searchField = - new SearchField(field, - field.getIndices().containsKey(field.getName()) && field.getIndices().get(field.getName()).getType().equals(Index.Type.VESPA), - field.getAttributes().containsKey(field.getName())); + new SearchField(field, + field.getIndices().containsKey(field.getName()) && field.getIndices().get(field.getName()).getType().equals(Index.Type.VESPA), + field.getAttributes().containsKey(field.getName())); searchDef.add(searchField); // Add field to views @@ -252,9 +253,13 @@ public class DocumentModelBuilder { } else if (type instanceof StructDataType) { StructDataType dt = (StructDataType) type; for (com.yahoo.document.Field field : dt.getFields()) { - if (field.getDataType() != type) { - // XXX deprecated: - field.setDataType(resolveTemporariesRecurse(field.getDataType(), repo, docs, replacements)); + var ft = field.getDataType(); + if (ft != type) { + var newft = resolveTemporariesRecurse(ft, repo, docs, replacements); + if (ft != newft) { + // XXX deprecated: + field.setDataType(newft); + } } } } @@ -298,89 +303,6 @@ public class DocumentModelBuilder { return null; } - @SuppressWarnings("deprecation") - private static void specialHandleAnnotationReference(NewDocumentType docType, Field field) { - DataType fieldType = specialHandleAnnotationReferenceRecurse(docType, field.getName(), field.getDataType()); - if (fieldType == null) { - return; - } - field.setDataType(fieldType); // XXX deprecated - } - - private static DataType specialHandleAnnotationReferenceRecurse(NewDocumentType docType, String fieldName, - DataType dataType) { - if (dataType instanceof TemporaryAnnotationReferenceDataType) { - TemporaryAnnotationReferenceDataType refType = (TemporaryAnnotationReferenceDataType)dataType; - if (refType.getId() != 0) { - return null; - } - AnnotationType target = docType.getAnnotationType(refType.getTarget()); - if (target == null) { - throw new RetryLaterException("Annotation '" + refType.getTarget() + "' in reference '" + fieldName + - "' does not exist."); - } - dataType = new AnnotationReferenceDataType(target); - addType(docType, dataType); - return dataType; - } - else if (dataType instanceof MapDataType) { - MapDataType t = (MapDataType)dataType; - DataType valueType = specialHandleAnnotationReferenceRecurse(docType, fieldName, t.getValueType()); - if (valueType == null) { - return null; - } - var mapType = new MapDataType(t.getKeyType(), valueType, t.getId()); - addType(docType, mapType); - return mapType; - } - else if (dataType instanceof ArrayDataType) { - ArrayDataType t = (ArrayDataType) dataType; - DataType nestedType = specialHandleAnnotationReferenceRecurse(docType, fieldName, t.getNestedType()); - if (nestedType == null) { - return null; - } - var lstType = new ArrayDataType(nestedType, t.getId()); - addType(docType, lstType); - return lstType; - } - else if (dataType instanceof WeightedSetDataType) { - WeightedSetDataType t = (WeightedSetDataType) dataType; - DataType nestedType = specialHandleAnnotationReferenceRecurse(docType, fieldName, t.getNestedType()); - if (nestedType == null) { - return null; - } - boolean c = t.createIfNonExistent(); - boolean r = t.removeIfZero(); - var lstType = new WeightedSetDataType(nestedType, c, r, t.getId()); - addType(docType, lstType); - return lstType; - } - return null; - } - - private static StructDataType handleStruct(NewDocumentType dt, SDDocumentType type) { - StructDataType s = new StructDataType(type.getName()); - for (Field f : type.getDocumentType().contentStruct().getFieldsThisTypeOnly()) { - specialHandleAnnotationReference(dt, f); - s.addField(f); - } - for (StructDataType inherited : type.getDocumentType().contentStruct().getInheritedTypes()) { - s.inherit(inherited); - } - extractNestedTypes(dt, s); - addType(dt, s); - return s; - } - - private static StructDataType handleStruct(NewDocumentType dt, StructDataType s) { - for (Field f : s.getFieldsThisTypeOnly()) { - specialHandleAnnotationReference(dt, f); - } - extractNestedTypes(dt, s); - addType(dt, s); - return s; - } - private static boolean anyParentsHavePayLoad(SDAnnotationType sa, SDDocumentType sdoc) { if (sa.getInherits() != null) { AnnotationType tmp = sdoc.findAnnotation(sa.getInherits()); @@ -391,8 +313,6 @@ public class DocumentModelBuilder { } private NewDocumentType convert(SDDocumentType sdoc) { - Map<AnnotationType, String> annotationInheritance = new HashMap<>(); - Map<StructDataType, String> structInheritance = new HashMap<>(); NewDocumentType dt = new NewDocumentType(new NewDocumentType.Name(sdoc.getName()), sdoc.getDocumentType().contentStruct(), sdoc.getFieldSets(), @@ -400,63 +320,223 @@ public class DocumentModelBuilder { convertTemporaryImportedFieldsToNames(sdoc.getTemporaryImportedFields())); for (SDDocumentType n : sdoc.getInheritedTypes()) { NewDocumentType.Name name = new NewDocumentType.Name(n.getName()); - NewDocumentType inherited = model.getDocumentManager().getDocumentType(name); - if (inherited != null) { - dt.inherit(inherited); - } - } - for (SDDocumentType type : sdoc.getTypes()) { - if (type.isStruct()) { - handleStruct(dt, type); - } else { - throw new IllegalArgumentException("Data type '" + sdoc.getName() + "' is not a struct => tostring='" + sdoc.toString() + "'."); - } - } - for (SDDocumentType type : sdoc.getTypes()) { - for (SDDocumentType proxy : type.getInheritedTypes()) { - var inherited = dt.getDataTypeRecursive(proxy.getName()); - var converted = (StructDataType) dt.getDataType(type.getName()); - converted.inherit((StructDataType) inherited); + NewDocumentType inherited = model.getDocumentManager().getDocumentType(name); + if (inherited != null) { + dt.inherit(inherited); } } - for (AnnotationType annotation : sdoc.getAnnotations().values()) { - dt.add(annotation); + var extractor = new TypeExtractor(dt); + extractor.extract(sdoc); + return dt; + } + + static class TypeExtractor { + private final NewDocumentType targetDt; + Map<AnnotationType, String> annotationInheritance = new HashMap<>(); + Map<StructDataType, String> structInheritance = new HashMap<>(); + private final Map<Object, Object> inProgress = new IdentityHashMap<>(); + TypeExtractor(NewDocumentType target) { + this.targetDt = target; } - for (AnnotationType annotation : sdoc.getAnnotations().values()) { - SDAnnotationType sa = (SDAnnotationType) annotation; - if (annotation.getInheritedTypes().isEmpty() && (sa.getInherits() != null) ) { - annotationInheritance.put(annotation, sa.getInherits()); + + void extract(SDDocumentType sdoc) { + for (SDDocumentType type : sdoc.getTypes()) { + if (type.isStruct()) { + handleStruct(type); + } else { + throw new IllegalArgumentException("Data type '" + type.getName() + "' is not a struct => tostring='" + type.toString() + "'."); + } + } + for (SDDocumentType type : sdoc.getTypes()) { + for (SDDocumentType proxy : type.getInheritedTypes()) { + var inherited = targetDt.getDataTypeRecursive(proxy.getName()); + var converted = (StructDataType) targetDt.getDataType(type.getName()); + converted.inherit((StructDataType) inherited); + } } - if (annotation.getDataType() == null) { - if (sa.getSdDocType() != null) { - StructDataType s = handleStruct(dt, sa.getSdDocType()); - annotation.setDataType(s); - if ((sa.getInherits() != null)) { + for (AnnotationType annotation : sdoc.getAnnotations().values()) { + targetDt.add(annotation); + } + for (AnnotationType annotation : sdoc.getAnnotations().values()) { + SDAnnotationType sa = (SDAnnotationType) annotation; + if (annotation.getInheritedTypes().isEmpty() && (sa.getInherits() != null) ) { + annotationInheritance.put(annotation, sa.getInherits()); + } + if (annotation.getDataType() == null) { + if (sa.getSdDocType() != null) { + StructDataType s = handleStruct(sa.getSdDocType()); + annotation.setDataType(s); + if ((sa.getInherits() != null)) { + structInheritance.put(s, "annotation."+sa.getInherits()); + } + } else if (sa.getInherits() != null) { + StructDataType s = new StructDataType("annotation."+annotation.getName()); + if (anyParentsHavePayLoad(sa, sdoc)) { + annotation.setDataType(s); + addType(s); + } structInheritance.put(s, "annotation."+sa.getInherits()); } - } else if (sa.getInherits() != null) { - StructDataType s = new StructDataType("annotation."+annotation.getName()); - if (anyParentsHavePayLoad(sa, sdoc)) { - annotation.setDataType(s); - addType(dt, s); + } + } + for (Map.Entry<AnnotationType, String> e : annotationInheritance.entrySet()) { + e.getKey().inherit(targetDt.getAnnotationType(e.getValue())); + } + for (Map.Entry<StructDataType, String> e : structInheritance.entrySet()) { + StructDataType s = (StructDataType)targetDt.getDataType(e.getValue()); + if (s != null) { + e.getKey().inherit(s); + } + } + handleStruct(sdoc.getDocumentType().contentStruct()); + + extractDataTypesFromFields(sdoc.fieldSet()); + } + + private void extractDataTypesFromFields(Collection<Field> fields) { + for (Field f : fields) { + DataType type = f.getDataType(); + if (testAddType(type)) { + extractNestedTypes(type); + addType(type); + } + } + } + + private void extractNestedTypes(DataType type) { + if (inProgress.containsKey(type)) { + return; + } + inProgress.put(type, this); + if (type instanceof StructDataType) { + StructDataType tmp = (StructDataType) type; + extractDataTypesFromFields(tmp.getFieldsThisTypeOnly()); + } else if (type instanceof DocumentType) { + throw new IllegalArgumentException("Can not handle nested document definitions. In document type '" + targetDt.getName().toString() + + "', we can not define document type '" + type.toString()); + } else if (type instanceof CollectionDataType) { + CollectionDataType tmp = (CollectionDataType) type; + extractNestedTypes(tmp.getNestedType()); + addType(tmp.getNestedType()); + } else if (type instanceof MapDataType) { + MapDataType tmp = (MapDataType) type; + extractNestedTypes(tmp.getKeyType()); + extractNestedTypes(tmp.getValueType()); + addType(tmp.getKeyType()); + addType(tmp.getValueType()); + } else if (type instanceof TemporaryAnnotationReferenceDataType) { + throw new IllegalArgumentException(type.toString()); + } + } + + private boolean testAddType(DataType type) { return internalAddType(type, true); } + + private boolean addType(DataType type) { return internalAddType(type, false); } + + private boolean internalAddType(DataType type, boolean dryRun) { + DataType oldType = targetDt.getDataTypeRecursive(type.getId()); + if (oldType == null) { + if ( ! dryRun) { + targetDt.add(type); + } + return true; + } else if ((type instanceof StructDataType) && (oldType instanceof StructDataType)) { + StructDataType s = (StructDataType) type; + StructDataType os = (StructDataType) oldType; + if ((os.getFieldCount() == 0) && (s.getFieldCount() > os.getFieldCount())) { + if ( ! dryRun) { + targetDt.replace(type); } - structInheritance.put(s, "annotation."+sa.getInherits()); + return true; } } + return false; } - for (Map.Entry<AnnotationType, String> e : annotationInheritance.entrySet()) { - e.getKey().inherit(dt.getAnnotationType(e.getValue())); + + + @SuppressWarnings("deprecation") + private void specialHandleAnnotationReference(Field field) { + DataType fieldType = specialHandleAnnotationReferenceRecurse(field.getName(), field.getDataType()); + if (fieldType == null) { + return; + } + field.setDataType(fieldType); // XXX deprecated + } + + private DataType specialHandleAnnotationReferenceRecurse(String fieldName, + DataType dataType) { + if (dataType instanceof TemporaryAnnotationReferenceDataType) { + TemporaryAnnotationReferenceDataType refType = (TemporaryAnnotationReferenceDataType)dataType; + if (refType.getId() != 0) { + return null; + } + AnnotationType target = targetDt.getAnnotationType(refType.getTarget()); + if (target == null) { + throw new RetryLaterException("Annotation '" + refType.getTarget() + "' in reference '" + fieldName + + "' does not exist."); + } + dataType = new AnnotationReferenceDataType(target); + addType(dataType); + return dataType; + } + else if (dataType instanceof MapDataType) { + MapDataType t = (MapDataType)dataType; + DataType valueType = specialHandleAnnotationReferenceRecurse(fieldName, t.getValueType()); + if (valueType == null) { + return null; + } + var mapType = new MapDataType(t.getKeyType(), valueType, t.getId()); + addType(mapType); + return mapType; + } + else if (dataType instanceof ArrayDataType) { + ArrayDataType t = (ArrayDataType) dataType; + DataType nestedType = specialHandleAnnotationReferenceRecurse(fieldName, t.getNestedType()); + if (nestedType == null) { + return null; + } + var lstType = new ArrayDataType(nestedType, t.getId()); + addType(lstType); + return lstType; + } + else if (dataType instanceof WeightedSetDataType) { + WeightedSetDataType t = (WeightedSetDataType) dataType; + DataType nestedType = specialHandleAnnotationReferenceRecurse(fieldName, t.getNestedType()); + if (nestedType == null) { + return null; + } + boolean c = t.createIfNonExistent(); + boolean r = t.removeIfZero(); + var lstType = new WeightedSetDataType(nestedType, c, r, t.getId()); + addType(lstType); + return lstType; + } + return null; } - for (Map.Entry<StructDataType, String> e : structInheritance.entrySet()) { - StructDataType s = (StructDataType)dt.getDataType(e.getValue()); - if (s != null) { - e.getKey().inherit(s); + + private StructDataType handleStruct(SDDocumentType type) { + StructDataType s = new StructDataType(type.getName()); + for (Field f : type.getDocumentType().contentStruct().getFieldsThisTypeOnly()) { + specialHandleAnnotationReference(f); + s.addField(f); + } + for (StructDataType inherited : type.getDocumentType().contentStruct().getInheritedTypes()) { + s.inherit(inherited); } + extractNestedTypes(s); + addType(s); + return s; } - handleStruct(dt, sdoc.getDocumentType().contentStruct()); - extractDataTypesFromFields(dt, sdoc.fieldSet()); - return dt; + private StructDataType handleStruct(StructDataType s) { + for (Field f : s.getFieldsThisTypeOnly()) { + specialHandleAnnotationReference(f); + } + extractNestedTypes(s); + addType(s); + return s; + } + } private static Set<NewDocumentType.Name> convertDocumentReferencesToNames(Optional<DocumentReferences> documentReferences) { @@ -464,9 +544,9 @@ public class DocumentModelBuilder { return Set.of(); } return documentReferences.get().referenceMap().values().stream() - .map(documentReference -> documentReference.targetSearch().getDocument()) - .map(documentType -> new NewDocumentType.Name(documentType.getName())) - .collect(Collectors.toCollection(() -> new LinkedHashSet<>())); + .map(documentReference -> documentReference.targetSearch().getDocument()) + .map(documentType -> new NewDocumentType.Name(documentType.getName())) + .collect(Collectors.toCollection(() -> new LinkedHashSet<>())); } private static Set<String> convertTemporaryImportedFieldsToNames(TemporaryImportedFields importedFields) { @@ -476,62 +556,6 @@ public class DocumentModelBuilder { return Collections.unmodifiableSet(importedFields.fields().keySet()); } - private static void extractDataTypesFromFields(NewDocumentType dt, Collection<Field> fields) { - for (Field f : fields) { - DataType type = f.getDataType(); - if (testAddType(dt, type)) { - extractNestedTypes(dt, type); - addType(dt, type); - } - } - } - - private static void extractNestedTypes(NewDocumentType dt, DataType type) { - if (type instanceof StructDataType) { - StructDataType tmp = (StructDataType) type; - extractDataTypesFromFields(dt, tmp.getFieldsThisTypeOnly()); - } else if (type instanceof DocumentType) { - throw new IllegalArgumentException("Can not handle nested document definitions. In document type '" + dt.getName().toString() + - "', we can not define document type '" + type.toString()); - } else if (type instanceof CollectionDataType) { - CollectionDataType tmp = (CollectionDataType) type; - extractNestedTypes(dt, tmp.getNestedType()); - addType(dt, tmp.getNestedType()); - } else if (type instanceof MapDataType) { - MapDataType tmp = (MapDataType) type; - extractNestedTypes(dt, tmp.getKeyType()); - extractNestedTypes(dt, tmp.getValueType()); - addType(dt, tmp.getKeyType()); - addType(dt, tmp.getValueType()); - } else if (type instanceof TemporaryAnnotationReferenceDataType) { - throw new IllegalArgumentException(type.toString()); - } - } - - private static boolean testAddType(NewDocumentType dt, DataType type) { return internalAddType(dt, type, true); } - - private static boolean addType(NewDocumentType dt, DataType type) { return internalAddType(dt, type, false); } - - private static boolean internalAddType(NewDocumentType dt, DataType type, boolean dryRun) { - DataType oldType = dt.getDataTypeRecursive(type.getId()); - if (oldType == null) { - if ( ! dryRun) { - dt.add(type); - } - return true; - } else if ((type instanceof StructDataType) && (oldType instanceof StructDataType)) { - StructDataType s = (StructDataType) type; - StructDataType os = (StructDataType) oldType; - if ((os.getFieldCount() == 0) && (s.getFieldCount() > os.getFieldCount())) { - if ( ! dryRun) { - dt.replace(type); - } - return true; - } - } - return false; - } - public static class RetryLaterException extends IllegalArgumentException { public RetryLaterException(String message) { super(message); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java index 16bf37902f5..5e5623e2319 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java @@ -35,17 +35,32 @@ public class FieldOperationApplierForStructs extends FieldOperationApplier { Iterator<Field> fields = anyType.fieldIterator(); while (fields.hasNext()) { SDField field = (SDField) fields.next(); - DataType structUsedByField = field.getFirstStructRecursive(); - if (structUsedByField == null) { - continue; - } - if (structUsedByField.getName().equals(structType.getName())) { - //this field is using this type!! - field.populateWithStructFields(sdoc, field.getName(), field.getDataType(), 0); - field.populateWithStructMatching(sdoc, field.getDataType(), field.getMatching()); - } + maybePopulateField(sdoc, field, structType); } } } + private void maybePopulateField(SDDocumentType sdoc, SDField field, SDDocumentType structType) { + DataType structUsedByField = field.getFirstStructRecursive(); + if (structUsedByField == null) { + return; + } + if (structUsedByField.getName().equals(structType.getName())) { + //this field is using this type!! + field.populateWithStructFields(sdoc, field.getName(), field.getDataType(), 0); + field.populateWithStructMatching(sdoc, field.getDataType(), field.getMatching()); + } + } + + public void processSchemaFields(Schema schema) { + var sdoc = schema.getDocument(); + if (sdoc == null) return; + for (SDDocumentType type : sdoc.getAllTypes()) { + if (type.isStruct()) { + for (SDField field : schema.allExtraFields()) { + maybePopulateField(sdoc, field, type); + } + } + } + } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index e419b8c93a7..e7b7c92ca2f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -745,8 +745,8 @@ public class RankProfile implements Cloneable { public RankingExpressionFunction addFunction(ExpressionFunction function, boolean inline) { RankingExpressionFunction rankingExpressionFunction = new RankingExpressionFunction(function, inline); if (functions.containsKey(function.getName())) { - deployLogger.log(Level.WARNING, "Function '" + function.getName() + "' replaces a previous function " + - "with the same name in rank profile '" + this.name + "'"); + deployLogger.log(Level.WARNING, "Function '" + function.getName() + "' is defined twice " + + "in rank profile '" + this.name + "'"); } functions.put(function.getName(), rankingExpressionFunction); allFunctionsCached = null; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java index d9bdf5dc917..0d1222e737b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java @@ -9,6 +9,7 @@ import com.yahoo.document.MapDataType; import com.yahoo.document.StructDataType; import com.yahoo.document.TemporaryStructuredDataType; import com.yahoo.document.TensorDataType; +import com.yahoo.document.WeightedSetDataType; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; import com.yahoo.language.simple.SimpleLinguistics; @@ -186,6 +187,10 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, ": Dense tensor dimensions must have a size"); addQueryCommand("type " + type); } + else if (dataType instanceof WeightedSetDataType) { + var nested = ((WeightedSetDataType) dataType).getNestedType().getName(); + addQueryCommand("type WeightedSet<" + nested + ">"); + } else { addQueryCommand("type " + dataType.getName()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java index dd697d51363..3cf5628d282 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedFields.java @@ -138,6 +138,7 @@ public class ConvertParsedFields { if (indexing.isPresent()) { field.setIndexingScript(indexing.get().script()); } + parsed.getWeight().ifPresent(value -> field.setWeight(value)); parsed.getStemming().ifPresent(value -> field.setStemming(value)); parsed.getNormalizing().ifPresent(value -> convertNormalizing(field, value)); for (var attribute : parsed.getAttributes()) { @@ -290,20 +291,23 @@ public class ConvertParsedFields { schema.addIndex(index); } - void convertStructDeclaration(Schema schema, SDDocumentType document, ParsedStruct parsed) { + SDDocumentType convertStructDeclaration(Schema schema, ParsedStruct parsed) { // TODO - can we cleanup this mess var structProxy = new SDDocumentType(parsed.name(), schema); structProxy.setStruct(context.resolveStruct(parsed)); - for (var structField : parsed.getFields()) { - var fieldType = context.resolveType(structField.getType()); - var field = new SDField(structProxy, structField.name(), fieldType); - convertCommonFieldSettings(field, structField); + for (var parsedField : parsed.getFields()) { + var fieldType = context.resolveType(parsedField.getType()); + var field = new SDField(structProxy, parsedField.name(), fieldType); + convertCommonFieldSettings(field, parsedField); structProxy.addField(field); + if (parsedField.hasIdOverride()) { + structProxy.setFieldId(field, parsedField.idOverride()); + } } for (String inherit : parsed.getInherited()) { structProxy.inherit(new DataTypeName(inherit)); } - document.addType(structProxy); + return structProxy; } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedTypes.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedTypes.java index 5a83b4d8a0e..e67c1ac8275 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedTypes.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedTypes.java @@ -37,10 +37,12 @@ public class ConvertParsedTypes { this.docMan = docMan; } - public void convert() { + public void convert(boolean andRegister) { startDataTypes(); fillDataTypes(); - registerDataTypes(); + if (andRegister) { + registerDataTypes(); + } } private Map<String, DocumentType> documentsFromSchemas = new HashMap<>(); @@ -55,8 +57,8 @@ public class ConvertParsedTypes { for (var schema : orderedInput) { var doc = schema.getDocument(); for (var struct : doc.getStructs()) { - var dt = new StructDataType(struct.name()); String structId = doc.name() + "->" + struct.name(); + var dt = new StructDataType(struct.name()); structsFromSchemas.put(structId, dt); } for (var annotation : doc.getAnnotations()) { @@ -80,10 +82,20 @@ public class ConvertParsedTypes { for (var struct : doc.getStructs()) { String structId = doc.name() + "->" + struct.name(); var toFill = structsFromSchemas.get(structId); + // evil ugliness for (ParsedField field : struct.getFields()) { - var t = resolveFromContext(field.getType(), doc); - var f = new com.yahoo.document.Field(field.name(), t); - toFill.addField(f); + if (! field.hasIdOverride()) { + var t = resolveFromContext(field.getType(), doc); + var f = new com.yahoo.document.Field(field.name(), t); + toFill.addField(f); + } + } + for (ParsedField field : struct.getFields()) { + if (field.hasIdOverride()) { + var t = resolveFromContext(field.getType(), doc); + var f = new com.yahoo.document.Field(field.name(), field.idOverride(), t); + toFill.addField(f); + } } for (String inherit : struct.getInherited()) { var parent = findStructFromSchemas(inherit, doc); @@ -100,7 +112,9 @@ public class ConvertParsedTypes { var toFill = structsFromSchemas.get(structId); for (ParsedField field : struct.getFields()) { var t = resolveFromContext(field.getType(), doc); - var f = new com.yahoo.document.Field(field.name(), t); + var f = field.hasIdOverride() + ? new com.yahoo.document.Field(field.name(), field.idOverride(), t) + : new com.yahoo.document.Field(field.name(), t); toFill.addField(f); } at.setDataType(toFill); @@ -116,8 +130,11 @@ public class ConvertParsedTypes { for (var docField : doc.getFields()) { String name = docField.name(); var t = resolveFromContext(docField.getType(), doc); - var f = new com.yahoo.document.Field(name, t); + var f = new com.yahoo.document.Field(docField.name(), t); docToFill.addField(f); + if (docField.hasIdOverride()) { + f.setId(docField.idOverride(), docToFill); + } inDocFields.add(name); } fieldSets.put("[document]", inDocFields); @@ -171,18 +188,6 @@ public class ConvertParsedTypes { throw new IllegalArgumentException("conflicting values for struct " + name + " in " +doc); } } - if (found == null) { - // TODO: be more restrictive here, but we need something - // for imported fields. For now, fall back to looking for - // struct in any schema. - for (var schema : orderedInput) { - for (var struct : schema.getDocument().getStructs()) { - if (struct.name().equals(name)) { - return struct; - } - } - } - } return found; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java index fb0003ca4f9..21a68744c19 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertSchemaCollection.java @@ -1,25 +1,43 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.parser; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.config.model.api.ModelContext; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.application.provider.MockFileRegistry; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.document.DataType; +import com.yahoo.document.DataTypeName; import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentTypeManager; +import com.yahoo.document.PositionDataType; import com.yahoo.document.ReferenceDataType; import com.yahoo.document.StructDataType; -import com.yahoo.document.PositionDataType; import com.yahoo.document.WeightedSetDataType; import com.yahoo.document.annotation.AnnotationReferenceDataType; import com.yahoo.document.annotation.AnnotationType; +import com.yahoo.searchdefinition.DefaultRankProfile; +import com.yahoo.searchdefinition.DocumentOnlySchema; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Schema; +import com.yahoo.searchdefinition.UnrankedRankProfile; +import com.yahoo.searchdefinition.document.SDDocumentType; +import com.yahoo.searchdefinition.document.SDField; +import com.yahoo.searchdefinition.document.TemporaryImportedField; +import com.yahoo.searchdefinition.document.annotation.SDAnnotationType; +import com.yahoo.searchdefinition.parser.ConvertParsedTypes.TypeResolver; +import com.yahoo.vespa.documentmodel.DocumentSummary; +import com.yahoo.vespa.documentmodel.SummaryField; import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; import java.util.List; -import java.util.Map; +import java.util.Optional; /** * Class converting a collection of schemas from the intermediate format. - * For now only conversion to DocumentType (with contents). * * @author arnej27959 **/ @@ -28,12 +46,45 @@ public class ConvertSchemaCollection { private final IntermediateCollection input; private final List<ParsedSchema> orderedInput = new ArrayList<>(); private final DocumentTypeManager docMan; + private final ApplicationPackage applicationPackage; + private final FileRegistry fileRegistry; + private final DeployLogger deployLogger; + private final ModelContext.Properties properties; + private final RankProfileRegistry rankProfileRegistry; + private final boolean documentsOnly; + + // for unit test + ConvertSchemaCollection(IntermediateCollection input, + DocumentTypeManager documentTypeManager) + { + this(input, documentTypeManager, + MockApplicationPackage.createEmpty(), + new MockFileRegistry(), + new BaseDeployLogger(), + new TestProperties(), + new RankProfileRegistry(), + true); + } public ConvertSchemaCollection(IntermediateCollection input, - DocumentTypeManager documentTypeManager) + DocumentTypeManager documentTypeManager, + ApplicationPackage applicationPackage, + FileRegistry fileRegistry, + DeployLogger deployLogger, + ModelContext.Properties properties, + RankProfileRegistry rankProfileRegistry, + boolean documentsOnly) { this.input = input; this.docMan = documentTypeManager; + this.applicationPackage = applicationPackage; + this.fileRegistry = fileRegistry; + this.deployLogger = deployLogger; + this.properties = properties; + this.rankProfileRegistry = rankProfileRegistry; + this.documentsOnly = documentsOnly; + + input.resolveInternalConnections(); order(); pushTypesToDocuments(); } @@ -64,8 +115,161 @@ public class ConvertSchemaCollection { } } + private ConvertParsedTypes typeConverter; + public void convertTypes() { - var converter = new ConvertParsedTypes(orderedInput, docMan); - converter.convert(); + typeConverter = new ConvertParsedTypes(orderedInput, docMan); + typeConverter.convert(true); } + + public List<Schema> convertToSchemas() { + typeConverter = new ConvertParsedTypes(orderedInput, docMan); + typeConverter.convert(false); + var resultList = new ArrayList<Schema>(); + for (var parsed : orderedInput) { + Optional<String> inherited; + var inheritList = parsed.getInherited(); + if (inheritList.size() == 0) { + inherited = Optional.empty(); + } else if (inheritList.size() == 1) { + inherited = Optional.of(inheritList.get(0)); + } else { + throw new IllegalArgumentException("schema " + parsed.name() + "cannot inherit more than once"); + } + Schema schema = parsed.getDocumentWithoutSchema() + ? new DocumentOnlySchema(applicationPackage, fileRegistry, deployLogger, properties) + : new Schema(parsed.name(), applicationPackage, inherited, fileRegistry, deployLogger, properties); + convertSchema(schema, parsed); + resultList.add(schema); + } + return resultList; + } + + private void convertAnnotation(Schema schema, SDDocumentType document, ParsedAnnotation parsed, ConvertParsedFields fieldConverter) { + var type = new SDAnnotationType(parsed.name()); + for (String inherit : parsed.getInherited()) { + type.inherit(inherit); + } + var payload = parsed.getStruct(); + if (payload.isPresent()) { + var struct = fieldConverter.convertStructDeclaration(schema, payload.get()); + type = new SDAnnotationType(parsed.name(), struct, type.getInherits()); + // WTF? + struct.setStruct(null); + } + document.addAnnotation(type); + } + + private void convertDocument(Schema schema, ParsedDocument parsed, + ConvertParsedFields fieldConverter) + { + SDDocumentType document = new SDDocumentType(parsed.name()); + for (String inherit : parsed.getInherited()) { + document.inherit(new DataTypeName(inherit)); + } + for (var struct : parsed.getStructs()) { + var structProxy = fieldConverter.convertStructDeclaration(schema, struct); + document.addType(structProxy); + } + for (var annotation : parsed.getAnnotations()) { + convertAnnotation(schema, document, annotation, fieldConverter); + } + for (var field : parsed.getFields()) { + var sdf = fieldConverter.convertDocumentField(schema, document, field); + if (field.hasIdOverride()) { + document.setFieldId(sdf, field.idOverride()); + } + } + schema.addDocument(document); + } + + private void convertDocumentSummary(Schema schema, ParsedDocumentSummary parsed, TypeResolver typeContext) { + var docsum = new DocumentSummary(parsed.name(), schema); + var inheritList = parsed.getInherited(); + if (inheritList.size() == 1) { + docsum.setInherited(inheritList.get(0)); + } else if (inheritList.size() != 0) { + throw new IllegalArgumentException("document-summary "+parsed.name()+" cannot inherit more than once"); + } + if (parsed.getFromDisk()) { + docsum.setFromDisk(true); + } + if (parsed.getOmitSummaryFeatures()) { + docsum.setOmitSummaryFeatures(true); + } + for (var parsedField : parsed.getSummaryFields()) { + DataType dataType = typeContext.resolveType(parsedField.getType()); + var summaryField = new SummaryField(parsedField.name(), dataType); + // XXX does not belong here: + summaryField.setVsmCommand(SummaryField.VsmCommand.FLATTENSPACE); + ConvertParsedFields.convertSummaryFieldSettings(summaryField, parsedField); + docsum.add(summaryField); + } + schema.addSummary(docsum); + } + + private void convertImportField(Schema schema, ParsedSchema.ImportedField f) { + // needs rethinking + var importedFields = schema.temporaryImportedFields().get(); + if (importedFields.hasField(f.asFieldName)) { + throw new IllegalArgumentException("For schema '" + schema.getName() + + "', import field as '" + f.asFieldName + + "': Field already imported"); + } + importedFields.add(new TemporaryImportedField(f.asFieldName, f.refFieldName, f.foreignFieldName)); + } + + private void convertFieldSet(Schema schema, ParsedFieldSet parsed) { + String setName = parsed.name(); + for (String field : parsed.getFieldNames()) { + schema.fieldSets().addUserFieldSetItem(setName, field); + } + for (String command : parsed.getQueryCommands()) { + schema.fieldSets().userFieldSets().get(setName).queryCommands().add(command); + } + if (parsed.getMatchSettings().isPresent()) { + // same ugliness as SDParser.jj used to have: + var tmp = new SDField(setName, DataType.STRING); + ConvertParsedFields.convertMatchSettings(tmp, parsed.matchSettings()); + schema.fieldSets().userFieldSets().get(setName).setMatching(tmp.getMatching()); + } + } + + private void convertSchema(Schema schema, ParsedSchema parsed) { + if (parsed.hasStemming()) { + schema.setStemming(parsed.getStemming()); + } + parsed.getRawAsBase64().ifPresent(value -> schema.enableRawAsBase64(value)); + var typeContext = typeConverter.makeContext(parsed.getDocument()); + var fieldConverter = new ConvertParsedFields(typeContext); + convertDocument(schema, parsed.getDocument(), fieldConverter); + for (var field : parsed.getFields()) { + fieldConverter.convertExtraField(schema, field); + } + for (var index : parsed.getIndexes()) { + fieldConverter.convertExtraIndex(schema, index); + } + for (var docsum : parsed.getDocumentSummaries()) { + convertDocumentSummary(schema, docsum, typeContext); + } + for (var importedField : parsed.getImportedFields()) { + convertImportField(schema, importedField); + } + for (var fieldSet : parsed.getFieldSets()) { + convertFieldSet(schema, fieldSet); + } + for (var rankingConstant : parsed.getRankingConstants()) { + schema.rankingConstants().add(rankingConstant); + } + for (var onnxModel : parsed.getOnnxModels()) { + schema.onnxModels().add(onnxModel); + } + rankProfileRegistry.add(new DefaultRankProfile(schema, rankProfileRegistry, schema.rankingConstants())); + rankProfileRegistry.add(new UnrankedRankProfile(schema, rankProfileRegistry, schema.rankingConstants())); + var rankConverter = new ConvertParsedRanking(rankProfileRegistry); + for (var rankProfile : parsed.getRankProfiles()) { + rankConverter.convertRankProfile(schema, rankProfile); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/InheritanceResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/InheritanceResolver.java index d9132d3aa24..4d011c1b596 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/InheritanceResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/InheritanceResolver.java @@ -15,6 +15,7 @@ public class InheritanceResolver { private final Map<String, ParsedSchema> parsedSchemas; private final Map<String, ParsedDocument> parsedDocs = new HashMap<>(); + private final Map<String, ParsedSchema> schemaForDocs = new HashMap<>(); public InheritanceResolver(Map<String, ParsedSchema> parsedSchemas) { this.parsedSchemas = parsedSchemas; @@ -24,7 +25,7 @@ public class InheritanceResolver { String name = schema.name(); if (seen.contains(name)) { seen.add(name); - throw new IllegalArgumentException("Inheritance cycle for schemas: " + + throw new IllegalArgumentException("Inheritance/reference cycle for schemas: " + String.join(" -> ", seen)); } seen.add(name); @@ -64,9 +65,13 @@ public class InheritanceResolver { if (old != null) { throw new IllegalArgumentException("duplicate document declaration for " + doc.name()); } + schemaForDocs.put(doc.name(), schema); for (String docInherit : doc.getInherited()) { schema.inheritByDocument(docInherit); } + for (String docReferenced : doc.getReferencedDocuments()) { + schema.inheritByDocument(docReferenced); + } } for (ParsedDocument doc : parsedDocs.values()) { for (String inherit : doc.getInherited()) { @@ -76,13 +81,20 @@ public class InheritanceResolver { } doc.resolveInherit(inherit, parentDoc); } + for (String docRefName : doc.getReferencedDocuments()) { + var refDoc = parsedDocs.get(docRefName); + if (refDoc == null) { + throw new IllegalArgumentException("document " + doc.name() + " references unavailable document " + docRefName); + } + doc.resolveReferenced(refDoc); + } } for (ParsedSchema schema : parsedSchemas.values()) { - for (String inherit : schema.getInheritedByDocument()) { - var parent = parsedSchemas.get(inherit); + for (String docName : schema.getInheritedByDocument()) { + var parent = schemaForDocs.get(docName); assert(parent.hasDocument()); - assert(parent.getDocument().name().equals(inherit)); - schema.resolveInheritByDocument(inherit, parent); + assert(parent.getDocument().name().equals(docName)); + schema.resolveInheritByDocument(docName, parent); } } } @@ -91,7 +103,7 @@ public class InheritanceResolver { String name = document.name(); if (seen.contains(name)) { seen.add(name); - throw new IllegalArgumentException("Inheritance cycle for documents: " + + throw new IllegalArgumentException("Inheritance/reference cycle for documents: " + String.join(" -> ", seen)); } seen.add(name); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocument.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocument.java index 065b66e22b1..ed975238067 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocument.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocument.java @@ -1,7 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.parser; - import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashMap; @@ -34,6 +33,19 @@ public class ParsedDocument extends ParsedBlock { ParsedStruct getStruct(String name) { return docStructs.get(name); } ParsedAnnotation getAnnotation(String name) { return docAnnotations.get(name); } + List<String> getReferencedDocuments() { + var result = new ArrayList<String>(); + for (var field : docFields.values()) { + var type = field.getType(); + if (type.getVariant() == ParsedType.Variant.DOC_REFERENCE) { + var docType = type.getReferencedDocumentType(); + assert(docType.getVariant() == ParsedType.Variant.DOCUMENT); + result.add(docType.name()); + } + } + return result; + } + void inherit(String other) { inherited.add(other); } void addField(ParsedField field) { @@ -54,6 +66,7 @@ public class ParsedDocument extends ParsedBlock { verifyThat(! docAnnotations.containsKey(annName), "already has annotation", annName); docAnnotations.put(annName, annotation); annotation.tagOwner(name()); + annotation.getStruct().ifPresent(s -> s.tagOwner(name())); } public String toString() { return "document " + name(); } @@ -64,5 +77,12 @@ public class ParsedDocument extends ParsedBlock { verifyThat(! resolvedInherits.containsKey(name), "double resolveInherit for", name); resolvedInherits.put(name, parsed); } + + void resolveReferenced(ParsedDocument parsed) { + // TODO - not really inheritance: + var old = resolvedInherits.put(parsed.name(), parsed); + assert(old == null || old == parsed); + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocumentSummary.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocumentSummary.java index 08f4946a218..25adc6f134f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocumentSummary.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedDocumentSummary.java @@ -28,10 +28,11 @@ class ParsedDocumentSummary extends ParsedBlock { List<ParsedSummaryField> getSummaryFields() { return List.copyOf(fields.values()); } List<String> getInherited() { return List.copyOf(inherited); } - void addField(ParsedSummaryField field) { + ParsedSummaryField addField(ParsedSummaryField field) { String fieldName = field.name(); - verifyThat(! fields.containsKey(fieldName), "already has field", fieldName); - fields.put(fieldName, field); + // TODO disallow this on Vespa 8 + // verifyThat(! fields.containsKey(fieldName), "already has field", fieldName); + return fields.put(fieldName, field); } void setFromDisk(boolean value) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedField.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedField.java index 5ee73abc28d..ca876997dc6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedField.java @@ -58,6 +58,7 @@ class ParsedField extends ParsedBlock { List<String> getQueryCommands() { return List.copyOf(queryCommands); } String lookupAliasedFrom(String alias) { return aliases.get(alias); } ParsedMatchSettings matchSettings() { return this.matchInfo; } + Optional<Integer> getWeight() { return Optional.ofNullable(weight); } Optional<Stemming> getStemming() { return Optional.ofNullable(stemming); } Optional<String> getNormalizing() { return Optional.ofNullable(normalizing); } Optional<ParsedIndexingOp> getIndexing() { return Optional.ofNullable(indexingOp); } @@ -121,10 +122,8 @@ class ParsedField extends ParsedBlock { void setStemming(Stemming stemming) { this.stemming = stemming; } void setWeight(int weight) { this.weight = weight; } - void addAttribute(ParsedAttribute attribute) { - String attrName = attribute.name(); - verifyThat(! attributes.containsKey(attrName), "already has attribute", attrName); - attributes.put(attrName, attribute); + ParsedAttribute attributeFor(String attrName) { + return attributes.computeIfAbsent(attrName, n -> new ParsedAttribute(n)); } void setIndexingOperation(ParsedIndexingOp idxOp) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java index 0801b613530..f028685b71a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java @@ -113,9 +113,10 @@ class ParsedRankProfile extends ParsedBlock { fieldsRankWeight.put(field, weight); } - void addFunction(ParsedRankFunction func) { - verifyThat(! functions.containsKey(func.name()), "already has function", func.name()); - functions.put(func.name(), func); + ParsedRankFunction addOrReplaceFunction(ParsedRankFunction func) { + // allowed with warning + // verifyThat(! functions.containsKey(func.name()), "already has function", func.name()); + return functions.put(func.name(), func); } void addMutateOperation(MutateOperation.Phase phase, String attrName, String operation) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java index bcbf14d9398..599dd6e2a7a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; /** * This class holds the extracted information after parsing @@ -30,7 +31,7 @@ public class ParsedSchema extends ParsedBlock { } private boolean documentWithoutSchema = false; - private boolean rawAsBase64 = false; // TODO Vespa 8 flip default + private Boolean rawAsBase64 = null; private ParsedDocument myDocument = null; private Stemming defaultStemming = null; private final List<ImportedField> importedFields = new ArrayList<>(); @@ -53,7 +54,7 @@ public class ParsedSchema extends ParsedBlock { } boolean getDocumentWithoutSchema() { return documentWithoutSchema; } - boolean getRawAsBase64() { return rawAsBase64; } + Optional<Boolean> getRawAsBase64() { return Optional.ofNullable(rawAsBase64); } boolean hasDocument() { return myDocument != null; } ParsedDocument getDocument() { return myDocument; } boolean hasStemming() { return defaultStemming != null; } @@ -82,8 +83,9 @@ public class ParsedSchema extends ParsedBlock { void addDocument(ParsedDocument document) { verifyThat(myDocument == null, "already has", myDocument, "so cannot add", document); - verifyThat(name().equals(document.name()), - "schema " + name() + "can only contain document named " + name() + ", was: "+ document.name()); + // TODO - disallow? + // verifyThat(name().equals(document.name()), + // "schema " + name() + " can only contain document named " + name() + ", was: "+ document.name()); this.myDocument = document; } @@ -156,16 +158,15 @@ public class ParsedSchema extends ParsedBlock { verifyThat(name.equals(parsed.name()), "resolveInherit name mismatch for", name); verifyThat(! resolvedInherits.containsKey(name), "double resolveInherit for", name); resolvedInherits.put(name, parsed); - var old = allResolvedInherits.put(name, parsed); + var old = allResolvedInherits.put("schema " + name, parsed); verifyThat(old == null || old == parsed, "conflicting resolveInherit for", name); } void resolveInheritByDocument(String name, ParsedSchema parsed) { verifyThat(inheritedByDocument.contains(name), "resolveInheritByDocument for non-inherited name", name); - verifyThat(name.equals(parsed.name()), "resolveInheritByDocument name mismatch for", name); - var old = allResolvedInherits.put(name, parsed); - verifyThat(old == null || old == parsed, "conflicting resolveInherit for", name); + var old = allResolvedInherits.put("document " + name, parsed); + verifyThat(old == null || old == parsed, "conflicting resolveInheritByDocument for", name); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedStruct.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedStruct.java index 17b20459c9c..b5f297cf5da 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedStruct.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedStruct.java @@ -15,10 +15,13 @@ import java.util.Map; public class ParsedStruct extends ParsedBlock { private final List<String> inherited = new ArrayList<>(); private final Map<String, ParsedField> fields = new LinkedHashMap<>(); + private final ParsedType asParsedType; private String ownedBy = null; public ParsedStruct(String name) { super(name, "struct"); + this.asParsedType = ParsedType.fromName(name); + asParsedType.setVariant(ParsedType.Variant.STRUCT); } List<ParsedField> getFields() { return List.copyOf(fields.values()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedType.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedType.java index 3aed90a58e1..d04277706a1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedType.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedType.java @@ -46,6 +46,7 @@ class ParsedType { case "raw": return Variant.BUILTIN; case "tag": return Variant.BUILTIN; case "position": return Variant.POSITION; + case "float16": return Variant.BUILTIN; } return Variant.UNKNOWN; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/clustercontroller/ClusterControllerContainer.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/clustercontroller/ClusterControllerContainer.java index 20c3e007e3b..dbc055ef02e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/clustercontroller/ClusterControllerContainer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/clustercontroller/ClusterControllerContainer.java @@ -28,6 +28,8 @@ import com.yahoo.vespa.model.container.PlatformBundles; import java.util.Set; import java.util.TreeSet; +import static com.yahoo.vespa.model.container.docproc.DocprocChains.DOCUMENT_TYPE_MANAGER_CLASS; + /** * Container implementation for cluster-controllers */ @@ -148,6 +150,8 @@ public class ClusterControllerContainer extends Container implements addComponent("reindexing-maintainer", "ai.vespa.reindexing.ReindexingMaintainer", REINDEXING_CONTROLLER_BUNDLE); + + addComponent(new SimpleComponent(DOCUMENT_TYPE_MANAGER_CLASS)); addHandler("reindexing-status", "ai.vespa.reindexing.http.ReindexingV1ApiHandler", "/reindexing/v1/*", diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index 8f95e390b07..fae12a63427 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -36,6 +36,7 @@ import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.Handler; import com.yahoo.vespa.model.container.component.SystemBindingPattern; import com.yahoo.vespa.model.container.configserver.ConfigserverCluster; +import com.yahoo.vespa.model.container.docproc.DocprocChains; import com.yahoo.vespa.model.utils.FileSender; import java.util.ArrayList; @@ -48,6 +49,7 @@ import java.util.Set; import java.util.stream.Collectors; import static com.yahoo.config.model.api.ApplicationClusterEndpoint.RoutingMethod.sharedLayer4; +import static com.yahoo.vespa.model.container.docproc.DocprocChains.DOCUMENT_TYPE_MANAGER_CLASS; /** * A container cluster that is typically set up from the user application. @@ -110,6 +112,7 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat addSimpleComponent("com.yahoo.container.jdisc.CertificateStoreProvider"); addSimpleComponent("com.yahoo.container.jdisc.AthenzIdentityProviderProvider"); addSimpleComponent(com.yahoo.container.core.documentapi.DocumentAccessProvider.class.getName()); + addSimpleComponent(DOCUMENT_TYPE_MANAGER_CLASS); addMetricsHandlers(); addTestrunnerComponentsIfTester(deployState); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/docproc/DocprocChains.java b/config-model/src/main/java/com/yahoo/vespa/model/container/docproc/DocprocChains.java index 09995f661e4..4b9897d0950 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/docproc/DocprocChains.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/docproc/DocprocChains.java @@ -7,6 +7,7 @@ import com.yahoo.container.jdisc.config.SessionConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.container.component.Component; +import com.yahoo.vespa.model.container.component.SimpleComponent; import com.yahoo.vespa.model.container.component.SystemBindingPattern; import com.yahoo.vespa.model.container.component.chain.Chains; import com.yahoo.vespa.model.container.component.chain.ProcessingHandler; @@ -15,19 +16,28 @@ import com.yahoo.vespa.model.container.component.chain.ProcessingHandler; * @author Einar M R Rosenvinge */ public class DocprocChains extends Chains<DocprocChain> { + + public static final String DOCUMENT_TYPE_MANAGER_CLASS = "com.yahoo.document.DocumentTypeManager"; + private final ProcessingHandler<DocprocChains> docprocHandler; - public DocprocChains(AbstractConfigProducer parent, String subId) { + public DocprocChains(AbstractConfigProducer<?> parent, String subId) { super(parent, subId); docprocHandler = new ProcessingHandler<>(this, "com.yahoo.docproc.jdisc.DocumentProcessingHandler"); addComponent(docprocHandler); + + if (! (getParent() instanceof ApplicationContainerCluster)) { + // All application containers already have a DocumentTypeManager, + // but this could also belong to e.g. a cluster controller. + addComponent(new SimpleComponent(DOCUMENT_TYPE_MANAGER_CLASS)); + } } - private void addComponent(Component component) { - if (!(getParent() instanceof ContainerCluster)) { + private void addComponent(Component<?, ?> component) { + if (!(getParent() instanceof ContainerCluster<?>)) { return; } - ((ContainerCluster) getParent()).addComponent(component); + ((ContainerCluster<?>) getParent()).addComponent(component); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..88139de7888 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -36,15 +36,13 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { - private final ApplicationPackage app; private final String modelPath; private final String defaultOutput; private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); - private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { - this.app = app; + private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); @@ -81,15 +79,7 @@ public class OnnxModelInfo { Set<Long> unboundSizes = new HashSet<>(); Map<String, Long> symbolicSizes = new HashMap<>(); resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes); - - TensorType type = TensorType.empty; - if (inputTypes.size() > 0 && onnxTypeInfo.needModelProbe(symbolicSizes)) { - type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes); - } - if (type.equals(TensorType.empty)) { - type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); - } - return type; + return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); } return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType()); } @@ -160,8 +150,7 @@ public class OnnxModelInfo { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); String json = onnxModelToJson(model, path); storeGeneratedInfo(json, path, app); - return jsonToModelInfo(json, app); - + return jsonToModelInfo(json); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -170,7 +159,7 @@ public class OnnxModelInfo { static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) { try { String json = readGeneratedInfo(path, app); - return jsonToModelInfo(json, app); + return jsonToModelInfo(json); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -213,7 +202,7 @@ public class OnnxModelInfo { return out.toString(); } - static public OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException { + static public OnnxModelInfo jsonToModelInfo(String json) throws IOException { ObjectMapper m = new ObjectMapper(); JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); @@ -233,7 +222,7 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + return new OnnxModelInfo(path, inputs, outputs, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { @@ -364,21 +353,6 @@ public class OnnxModelInfo { return builder.build(); } - boolean needModelProbe(Map<String, Long> symbolicSizes) { - for (OnnxDimensionInfo onnxDimension : dimensions) { - if (onnxDimension.hasSymbolicName()) { - if (symbolicSizes == null) - return true; - if ( ! symbolicSizes.containsKey(onnxDimension.getSymbolicName())) { - return true; - } - } else if (onnxDimension.getSize() == 0) { - return true; - } - } - return false; - } - @Override public String toString() { return "(" + valueType.id() + ")" + diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java deleted file mode 100644 index 2ef81e3f1fa..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ /dev/null @@ -1,153 +0,0 @@ -package com.yahoo.vespa.model.ml; - -import com.fasterxml.jackson.core.JsonEncoding; -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.IOUtils; -import com.yahoo.path.Path; -import com.yahoo.tensor.TensorType; - -import java.io.BufferedReader; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Map; - -/** - * Defers to 'vespa-analyze-onnx-model' to determine the output type given - * a set of inputs. For situations with symbolic dimension sizes that can't - * easily be determined. - * - * @author lesters - */ -public class OnnxModelProbe { - - private static final String binary = "vespa-analyze-onnx-model"; - - static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { - TensorType outputType = TensorType.empty; - String contextKey = createContextKey(outputName, inputTypes); - - try { - // Check if output type has already been probed - outputType = readProbedOutputType(app, modelPath, contextKey); - - // Otherwise, run vespa-analyze-onnx-model if the model is available - if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { - String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); - String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); - outputType = outputTypeFromJson(jsonOutput, outputName); - if ( ! outputType.equals(TensorType.empty)) { - writeProbedOutputType(app, modelPath, contextKey, outputType); - } - } - - } catch (IOException | InterruptedException e) { - e.printStackTrace(System.err); - } - - return outputType; - } - - private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { - StringBuilder key = new StringBuilder().append(onnxName).append(":"); - inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) - .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(",")); - return key.substring(0, key.length()-1); - } - - private static Path probedOutputTypesPath(Path path) { - String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types"; - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); - } - - static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output, - Map<String, TensorType> inputTypes, TensorType type) throws IOException { - writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type); - } - - private static void writeProbedOutputType(ApplicationPackage app, Path modelPath, - String contextKey, TensorType type) throws IOException { - String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath(); - IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true); - } - - private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath, - String contextKey) throws IOException { - ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath)); - if ( ! file.exists()) { - return TensorType.empty; - } - try (BufferedReader reader = new BufferedReader(file.createReader())) { - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String key = parts[0]; - if (key.equals(contextKey)) { - return TensorType.fromSpec(parts[1]); - } - } - } - return TensorType.empty; - } - - private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { - ObjectMapper m = new ObjectMapper(); - JsonNode root = m.readTree(json); - if ( ! root.isObject() || ! root.has("outputs")) { - return TensorType.empty; - } - JsonNode outputs = root.get("outputs"); - if ( ! outputs.has(outputName)) { - return TensorType.empty; - } - return TensorType.fromSpec(outputs.get(outputName).asText()); - } - - private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); - g.writeStartObject(); - g.writeStringField("model", modelPath); - g.writeObjectFieldStart("inputs"); - for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) { - g.writeStringField(input.getKey(), input.getValue().toString()); - } - g.writeEndObject(); - g.writeEndObject(); - g.close(); - return out.toString(); - } - - private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { - ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); - processBuilder.redirectErrorStream(true); - StringBuilder output = new StringBuilder(); - Process process = processBuilder.start(); - - // Write json array to process stdin - OutputStream os = process.getOutputStream(); - os.write(jsonInput.getBytes(StandardCharsets.UTF_8)); - os.close(); - - // Read output from stdout/stderr - InputStream inputStream = process.getInputStream(); - while (true) { - int b = inputStream.read(); - if (b == -1) break; - output.append((char)b); - } - int returnCode = process.waitFor(); - if (returnCode != 0) { - throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". Output:\n" + output); - } - return output.toString(); - } - -} |