diff options
-rw-r--r-- | python/vespa/vespa/application.py | 8 | ||||
-rw-r--r-- | python/vespa/vespa/package.py | 193 |
2 files changed, 196 insertions, 5 deletions
diff --git a/python/vespa/vespa/application.py b/python/vespa/vespa/application.py index 3ab31c4dd8c..6ccf5e8d854 100644 --- a/python/vespa/vespa/application.py +++ b/python/vespa/vespa/application.py @@ -15,6 +15,7 @@ class Vespa(object): url: str, port: Optional[int] = None, deployment_message: Optional[List[str]] = None, + cert: Optional[str] = None, ) -> None: """ Establish a connection with a Vespa application. @@ -22,14 +23,17 @@ class Vespa(object): :param url: URL :param port: Port :param deployment_message: Message returned by Vespa engine after deployment. + :param cert: Path to certificate and key file >>> Vespa(url = "https://cord19.vespa.ai") >>> Vespa(url = "http://localhost", port = 8080) + >>> Vespa(url = "https://api.vespa-external.aws.oath.cloud", port = 4443, cert = "/path/to/cert-and-key.pem") """ self.url = url self.port = port self.deployment_message = deployment_message + self.cert = cert if port is None: self.end_point = self.url @@ -87,7 +91,7 @@ class Vespa(object): if debug_request: return VespaResult(vespa_result={}, request_body=body) else: - r = post(self.search_end_point, json=body) + r = post(self.search_end_point, json=body, cert=self.cert) return VespaResult(vespa_result=r.json()) def feed_data_point(self, schema: str, data_id: str, fields: Dict) -> Response: @@ -103,7 +107,7 @@ class Vespa(object): self.end_point, schema, schema, str(data_id) ) vespa_format = {"fields": fields} - response = post(end_point, json=vespa_format) + response = post(end_point, json=vespa_format, cert=self.cert) return response def collect_training_data_point( diff --git a/python/vespa/vespa/package.py b/python/vespa/vespa/package.py index 4b5d1e701d5..03f394f4aa5 100644 --- a/python/vespa/vespa/package.py +++ b/python/vespa/vespa/package.py @@ -1,11 +1,23 @@ +import http.client +import json import os import re -from time import sleep -from typing import List, Mapping, Optional +import tempfile +import zipfile +from base64 import standard_b64encode +from 
# Reconstructed from the whitespace-collapsed diff hunks for python/vespa/vespa/package.py.
# `to_application_zip` is a method of ApplicationPackage (per the hunk header); its class
# header lies outside this excerpt, which is why it appears here at module indentation.
def to_application_zip(self, extras: Optional[dict] = None) -> BytesIO:
    """
    Build an in-memory zip archive of the application package.

    The archive contains the schema definition file and services.xml, plus any extra files.
    The returned buffer is left positioned wherever the zip writer stopped, so callers that
    want to ``read()`` it must ``seek(0)`` first.

    :param extras: Optional mapping from path inside the archive to file content (str or bytes).
    :return: BytesIO buffer holding the zip archive.
    """
    buffer = BytesIO()
    with zipfile.ZipFile(buffer, 'a') as zip_archive:
        zip_archive.writestr("application/schemas/{}.sd".format(self.schema.name), self.schema_to_text)
        zip_archive.writestr("application/services.xml", self.services_to_text)
        # `extras` defaults to None instead of a shared mutable dict.
        for name, value in (extras or {}).items():
            zip_archive.writestr(name, value)
    return buffer


class VespaCloud(object):
    # Messages raised by follow_deployment for each non-successful terminal run status.
    _DEPLOYMENT_FAILURES = {
        'error': "Unexpected error during deployment; see log for details",
        'aborted': "Deployment was aborted, probably by a newer deployment",
        'outOfCapacity': "No capacity left in zone; please contact the Vespa team",
        'deploymentFailed': "Deployment failed; see log for details",
        'installationFailed': "Installation failed; see Vespa log for details",
        'running': "Deployment not completed",
        'endpointCertificateTimeout': "Endpoint certificate not ready in time; please contact Vespa team",
        'testFailure': "Unexpected status; tests are not run for manual deployments",
    }

    def __init__(self, tenant: str, application: str, key_location: str) -> None:
        """
        Deploy application to the Vespa Cloud (cloud.vespa.ai)

        Reads the API key, generates a self-signed data-plane certificate, writes key and
        certificate to a temporary file and creates the HTTPS connection object for the API.

        :param tenant: Tenant name registered in the Vespa Cloud.
        :param application: Application name registered in the Vespa Cloud.
        :param key_location: Location of the private key used for signing HTTP requests to the Vespa Cloud.
        """
        self.tenant = tenant
        self.application = application
        self.api_key = self.read_private_key(key_location)
        public_key_pem = self.api_key.public_key().public_bytes(
            serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
        )
        self.api_public_key_bytes = standard_b64encode(public_key_pem)
        self.data_key, self.data_certificate = self.create_certificate_pair()
        self.data_cert_file = self.write_private_key_and_cert(self.data_key, self.data_certificate)
        self.connection = http.client.HTTPSConnection('api.vespa-external.aws.oath.cloud', 4443)

    @staticmethod
    def read_private_key(key_location: str) -> "ec.EllipticCurvePrivateKey":
        """
        Load an unencrypted PEM private key from disk.

        :param key_location: Path to the PEM file.
        :return: The loaded elliptic curve private key.
        :raises TypeError: If the key is not an elliptic curve private key.
        """
        with open(key_location, 'rb') as key_data:
            key = serialization.load_pem_private_key(key_data.read(), None, default_backend())
        if not isinstance(key, ec.EllipticCurvePrivateKey):
            raise TypeError("Key at " + key_location + " must be an elliptic curve private key")
        return key

    @staticmethod
    def write_private_key_and_cert(key: "ec.EllipticCurvePrivateKey", cert: "x509.Certificate"):
        """
        Write the private key followed by the certificate, both PEM encoded, to a temporary file.

        :param key: Private key to write (unencrypted).
        :param cert: Certificate to write.
        :return: The open ``tempfile.NamedTemporaryFile`` object; use its ``name`` attribute
            as the file path. The file is removed when the object is closed.
        """
        cert_file = tempfile.NamedTemporaryFile('wt')
        key_pem = key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.TraditionalOpenSSL,
            serialization.NoEncryption(),
        )
        cert_file.write(key_pem.decode('UTF-8'))
        cert_file.write(cert.public_bytes(serialization.Encoding.PEM).decode('UTF-8'))
        # Flush so the content is on disk for readers that open the file by name.
        cert_file.flush()
        return cert_file

    @staticmethod
    def create_certificate_pair() -> tuple:
        """
        Generate a new elliptic curve key and a self-signed certificate for it.

        The certificate has common name 'localhost' and is valid from one minute ago until
        seven days from now.

        :return: Tuple of (private key, certificate).
        """
        # generate_private_key expects a curve instance, not the curve class.
        key = ec.generate_private_key(ec.SECP384R1(), default_backend())
        name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, u'localhost')])
        now = datetime.utcnow()
        certificate = x509.CertificateBuilder() \
            .subject_name(name) \
            .issuer_name(name) \
            .serial_number(x509.random_serial_number()) \
            .not_valid_before(now - timedelta(minutes=1)) \
            .not_valid_after(now + timedelta(days=7)) \
            .public_key(key.public_key()) \
            .sign(key, hashes.SHA256(), default_backend())
        return (key, certificate)

    @staticmethod
    def _error_message(raw: bytes) -> str:
        """Return the 'message' field of a JSON error body, or the decoded raw body if there is none."""
        try:
            return str(json.loads(raw)['message'])
        except (ValueError, KeyError, TypeError):
            return raw.decode('UTF-8', errors='replace')

    def request(self, method: str, path: str, body: Optional[BytesIO] = None,
                headers: Optional[Mapping] = None) -> dict:
        """
        Send a signed HTTP request over the API connection and return the parsed JSON response.

        :param method: HTTP method, e.g. 'GET' or 'POST'.
        :param path: Request path, starting with '/'.
        :param body: Request body; an empty body is used when omitted.
        :param headers: Extra headers; these take precedence over the signing headers.
        :return: The JSON response body, parsed.
        :raises RuntimeError: If the response status is not 200.
        """
        if body is None:
            body = BytesIO()
        digest = hashes.Hash(hashes.SHA256(), default_backend())
        body.seek(0)
        digest.update(body.read())
        content_hash = standard_b64encode(digest.finalize()).decode('UTF-8')
        timestamp = datetime.utcnow().isoformat() + 'Z'  # Java's Instant.parse requires the neutral time zone appended
        url = 'https://' + self.connection.host + ":" + str(self.connection.port) + path

        canonical_message = method + '\n' + url + '\n' + timestamp + '\n' + content_hash
        signature = self.api_key.sign(canonical_message.encode('UTF-8'), ec.ECDSA(hashes.SHA256()))

        request_headers = {
            "X-Timestamp": timestamp,
            "X-Content-Hash": content_hash,
            "X-Key-Id": self.tenant + ':' + self.application + ':' + "default",
            "X-Key": self.api_public_key_bytes,
            "X-Authorization": standard_b64encode(signature),
            **(headers or {})
        }

        # Rewind: the body was consumed above when computing the content hash.
        body.seek(0)
        self.connection.request(method, path, body, request_headers)
        with self.connection.getresponse() as response:
            status = response.status
            raw = response.read()
        # Check the status before parsing, so a non-JSON error body cannot hide the status code.
        if status != 200:
            raise RuntimeError("Status code " + str(status) + " doing " + method + " at " + url + ":\n"
                               + self._error_message(raw))
        return json.loads(raw)

    def get_dev_region(self) -> str:
        """Return the 'name' field of the default dev zone reported by the API."""
        return self.request('GET', '/zone/v1/environment/dev/default')['name']

    def get_endpoint(self, instance: str, region: str) -> str:
        """
        Look up the URL of the 'test_app_container' cluster for a dev deployment.

        :param instance: Instance name.
        :param region: Region of the dev deployment.
        :return: URL of the first matching endpoint.
        :raises RuntimeError: If no endpoint belongs to the 'test_app_container' cluster.
        """
        deployment_path = '/application/v4/tenant/{}/application/{}/instance/{}/environment/dev/region/{}' \
            .format(self.tenant, self.application, instance, region)
        endpoints = self.request('GET', deployment_path)['endpoints']
        container_url = [endpoint['url'] for endpoint in endpoints if endpoint['cluster'] == 'test_app_container']
        if not container_url:
            raise RuntimeError("No endpoints found for container 'test_app_container'")
        return container_url[0]

    def start_deployment(self, instance: str, job: str, application_package: "ApplicationPackage") -> int:
        """
        Upload the application package as a zip and start the given deployment job.

        The generated data-plane certificate is added to the zip as
        'application/security/clients.pem'. The response message is printed.

        :param instance: Instance name.
        :param job: Job name, e.g. 'dev-<region>'.
        :param application_package: ApplicationPackage to deploy.
        :return: The 'run' field of the response.
        """
        deploy_path = '/application/v4/tenant/{}/application/{}/instance/{}/deploy/{}' \
            .format(self.tenant, self.application, instance, job)
        client_certificate = self.data_certificate.public_bytes(serialization.Encoding.PEM)
        application_zip_bytes = application_package.to_application_zip(
            {'application/security/clients.pem': client_certificate}
        )
        response = self.request('POST', deploy_path, application_zip_bytes, {'Content-Type': 'application/zip'})
        print(response['message'])
        return response['run']

    def follow_deployment(self, instance: str, job: str, run: int) -> None:
        """
        Poll the given run once per second, printing new log entries, until it is no longer active.

        :param instance: Instance name.
        :param job: Job name.
        :param run: Run id, as returned by start_deployment.
        :raises RuntimeError: If the run ends with any status other than 'success'.
        """
        last = -1
        while True:
            update = self.request('GET',
                                  '/application/v4/tenant/{}/application/{}/instance/{}/job/{}/run/{}?after={}'
                                  .format(self.tenant, self.application, instance, job, run, last))

            for step, entries in update['log'].items():
                for entry in entries:
                    self.print_log_entry(step, entry)
            # Only entries after 'lastId' are requested on the next poll.
            last = update.get('lastId', last)

            if update['active']:
                sleep(1)
                continue

            status = update['status']
            if status == 'success':
                return
            if status in self._DEPLOYMENT_FAILURES:
                raise RuntimeError(self._DEPLOYMENT_FAILURES[status])
            raise RuntimeError("Unexpected status '" + status + "'")

    @staticmethod
    def print_log_entry(step: str, entry: dict):
        """
        Print one deployment log entry as 'TYPE    [HH:MM:SS] message'.

        Entries of the 'copyVespaLogs' step are printed only when their type is 'error'.

        :param step: Name of the deployment step the entry belongs to.
        :param entry: Log entry with 'at' (milliseconds since epoch), 'type' and 'message'.
        """
        timestamp = strftime('%H:%M:%S', gmtime(entry['at'] / 1e3))
        # Indent continuation lines of multi-line messages.
        message = entry['message'].replace('\n', '\n' + ' '*23)
        if step != 'copyVespaLogs' or entry['type'] == 'error':
            print('{:<7} [{}] {}'.format(entry['type'].upper(), timestamp, message))

    def deploy(self, instance: str, application_package: "ApplicationPackage") -> "Vespa":
        """
        Deploy the given application package as the given instance in the Vespa Cloud dev environment.

        :param instance: Name of this instance of the application, in the Vespa Cloud.
        :param application_package: ApplicationPackage to be deployed.

        :return: a Vespa connection instance.
        """
        region = self.get_dev_region()
        job = 'dev-' + region
        run = self.start_deployment(instance, job, application_package)
        self.follow_deployment(instance, job, run)
        endpoint_url = self.get_endpoint(instance, region)
        return Vespa(url=endpoint_url, cert=self.data_cert_file.name)

    def delete(self, instance: str):
        """
        Delete the specified instance from the dev environment in the Vespa Cloud.

        :param instance: Name of the instance to delete.
        :return:
        """
        delete_path = '/application/v4/tenant/{}/application/{}/instance/{}/environment/dev/region/{}' \
            .format(self.tenant, self.application, instance, self.get_dev_region())
        print(self.request('DELETE', delete_path)['message'])

    def close(self):
        """Close the API connection and the temporary key/certificate file."""
        try:
            self.connection.close()
        finally:
            # Always release the temporary file, even if closing the connection fails.
            self.data_cert_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()