summaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authorJon Marius Venstad <venstad@gmail.com>2020-08-27 14:05:06 +0200
committerJon Marius Venstad <venstad@gmail.com>2020-08-27 14:05:06 +0200
commit657ccdfad595192340758093a1eef0aab4c12310 (patch)
tree473df60351c92e4a79e490475177b50f2255a387 /python
parentfdec75f5e841b06015ebb1801b7f73ee87caa934 (diff)
Talk to Vespa Cloud with API key, deploy package with mTLS, etc.
Diffstat (limited to 'python')
-rw-r--r--python/vespa/vespa/package.py185
1 files changed, 182 insertions, 3 deletions
diff --git a/python/vespa/vespa/package.py b/python/vespa/vespa/package.py
index 4b5d1e701d5..3deb4ec8c50 100644
--- a/python/vespa/vespa/package.py
+++ b/python/vespa/vespa/package.py
@@ -1,11 +1,23 @@
+import http.client
+import json
import os
import re
-from time import sleep
-from typing import List, Mapping, Optional
+import tempfile
+import zipfile
+from base64 import standard_b64encode
+from datetime import datetime, timedelta
+from io import BytesIO
from pathlib import Path
+from time import sleep, strftime, gmtime
+from typing import List, Mapping, Optional
-from jinja2 import Environment, PackageLoader, select_autoescape
import docker
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives.asymmetric import ec
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from jinja2 import Environment, PackageLoader, select_autoescape
from vespa.json_serialization import ToJson, FromJson
from vespa.application import Vespa
@@ -358,6 +370,16 @@ class ApplicationPackage(ToJson, FromJson["ApplicationPackage"]):
with open(os.path.join(dir_path, "application/services.xml"), "w") as f:
f.write(self.services_to_text)
def to_application_zip(self, extras: Optional[dict] = None) -> BytesIO:
    """
    Build an in-memory zip archive of this application package.

    The archive contains the schema definition at "application/schemas/<schema name>.sd" and
    "application/services.xml", followed by any extra entries supplied by the caller.

    :param extras: Optional mapping from archive path to content (str or bytes) of additional
        files to include in the archive. Defaults to no extra files.

    :return: BytesIO holding the zip archive, positioned at the start of the data.
    """
    # A None default avoids sharing one mutable dict between all calls.
    extra_files = {} if extras is None else extras
    buffer = BytesIO()
    # The buffer is always new and empty, so write mode is what is meant here.
    with zipfile.ZipFile(buffer, 'w') as zip_archive:
        zip_archive.writestr("application/schemas/{}.sd".format(self.schema.name), self.schema_to_text)
        zip_archive.writestr("application/services.xml", self.services_to_text)
        for name, value in extra_files.items():
            zip_archive.writestr(name, value)
    # Rewind so that callers which simply read() the buffer get the whole archive.
    buffer.seek(0)
    return buffer
+
+
@staticmethod
def from_dict(mapping: Mapping) -> "ApplicationPackage":
schema = mapping.get("schema", None)
@@ -468,3 +490,160 @@ class VespaDocker(object):
port=self.local_port,
deployment_message=deployment_message,
)
+
+
class VespaCloud(object):
    """
    Client for deploying application packages to the dev environment of the Vespa Cloud.

    Requests to the Vespa Cloud API are signed with the API private key given at construction.
    A separate, self-signed key and certificate pair is generated for each instance: the
    certificate is added to every deployed application package, and the file holding the pair
    is passed on to the Vespa connection returned by deploy().

    An instance holds an HTTPS connection object and a temporary file; use it as a context
    manager, or call close() when done.
    """

    # Terminal run statuses other than 'success', with the error message raised for each.
    _FAILURE_MESSAGES = {
        'error': "Unexpected error during deployment; see log for details",
        'aborted': "Deployment was aborted, probably by a newer deployment",
        'outOfCapacity': "No capacity left in zone; please contact the Vespa team",
        'deploymentFailed': "Deployment failed; see log for details",
        'installationFailed': "Installation failed; see Vespa log for details",
        'running': "Deployment not completed",
        'endpointCertificateTimeout': "Endpoint certificate not ready in time; please contact Vespa team",
        'testFailure': "Unexpected status; tests are not run for manual deployments",
    }

    def __init__(self, tenant: str, application: str, key_location: str) -> None:
        """
        Deploy application to the Vespa Cloud (cloud.vespa.ai)

        :param tenant: Tenant name registered in the Vespa Cloud.
        :param application: Application name registered in the Vespa Cloud.
        :param key_location: Location of the private key used for signing HTTP requests to the Vespa Cloud.
        """
        self.tenant = tenant
        self.application = application
        self.api_key = self.read_private_key(key_location)
        # Base64 of the PEM-encoded public key; sent as the "X-Key" header of every request.
        self.api_public_key_bytes = standard_b64encode(
            self.api_key.public_key().public_bytes(
                serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
            )
        )
        self.data_key, self.data_certificate = self.create_certificate_pair()
        self.data_cert_file = self.write_private_key_and_cert(self.data_key, self.data_certificate)
        self.connection = http.client.HTTPSConnection('api.vespa-external.aws.oath.cloud', 4443)

    @staticmethod
    def read_private_key(key_location: str) -> ec.EllipticCurvePrivateKey:
        """
        Load an unencrypted, PEM-encoded elliptic curve private key from file.

        :param key_location: Path of the PEM file to read.

        :return: The private key.

        :raises TypeError: if the file holds a private key which is not an elliptic curve key.
        """
        with open(key_location, 'rb') as key_data:
            key = serialization.load_pem_private_key(key_data.read(), None, default_backend())
        if not isinstance(key, ec.EllipticCurvePrivateKey):
            raise TypeError("Key at " + key_location + " must be an elliptic curve private key")
        return key

    @staticmethod
    def write_private_key_and_cert(key: ec.EllipticCurvePrivateKey, cert: x509.Certificate):
        """
        Write the given private key, followed by the given certificate, PEM encoded, to a temporary file.

        :param key: Private key to write, unencrypted.
        :param cert: Certificate to write.

        :return: The open, flushed NamedTemporaryFile; its path is available as ``.name``.
            The caller owns the file object and must close it.
        """
        cert_file = tempfile.NamedTemporaryFile('wt')
        key_pem = key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.TraditionalOpenSSL,
            serialization.NoEncryption(),
        )
        cert_file.write(key_pem.decode('UTF-8'))
        cert_file.write(cert.public_bytes(serialization.Encoding.PEM).decode('UTF-8'))
        # Flush so that the content is on disk for whoever opens the file by name.
        cert_file.flush()
        return cert_file

    @staticmethod
    def create_certificate_pair() -> (ec.EllipticCurvePrivateKey, x509.Certificate):
        """
        Generate a new elliptic curve private key and a matching self-signed certificate.

        The certificate has common name 'localhost', and is valid from one minute before
        creation until seven days after.

        :return: Tuple of the private key and the certificate.
        """
        # generate_private_key expects a curve instance, not the curve class.
        key = ec.generate_private_key(ec.SECP521R1(), default_backend())
        name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, u'localhost')])
        now = datetime.utcnow()
        certificate = x509.CertificateBuilder() \
            .subject_name(name) \
            .issuer_name(name) \
            .serial_number(x509.random_serial_number()) \
            .not_valid_before(now - timedelta(minutes=1)) \
            .not_valid_after(now + timedelta(days=7)) \
            .public_key(key.public_key()) \
            .sign(key, hashes.SHA256(), default_backend())
        return (key, certificate)

    @staticmethod
    def _error_message(raw: bytes) -> str:
        """Return the 'message' field of a JSON error body, or the body as text when there is none."""
        text = raw.decode('UTF-8', errors='replace')
        try:
            parsed = json.loads(text)
        except ValueError:
            return text
        if isinstance(parsed, dict) and 'message' in parsed:
            return str(parsed['message'])
        return text

    def request(self, method: str, path: str, body: Optional[BytesIO] = None, headers: Optional[dict] = None) -> dict:
        """
        Send a signed request to the Vespa Cloud API and return the parsed JSON response.

        The signature covers the method, the full URL, a UTC timestamp and the base64 encoded
        SHA-256 hash of the body, joined by newlines.

        :param method: HTTP method, e.g. 'GET' or 'POST'.
        :param path: Path and query of the request, starting with '/'.
        :param body: Request body; defaults to an empty body.
        :param headers: Additional headers; these take precedence over the signing headers.

        :return: The JSON response body, parsed.

        :raises RuntimeError: if the response status code is not 200.
        """
        # None defaults avoid sharing one BytesIO and one dict between all calls.
        if body is None:
            body = BytesIO()
        extra_headers = {} if headers is None else headers

        digest = hashes.Hash(hashes.SHA256(), default_backend())
        body.seek(0)
        digest.update(body.read())
        content_hash = standard_b64encode(digest.finalize()).decode('UTF-8')
        timestamp = datetime.utcnow().isoformat() + 'Z'  # Java's Instant.parse requires the neutral time zone appended
        url = 'https://' + self.connection.host + ":" + str(self.connection.port) + path

        canonical_message = method + '\n' + url + '\n' + timestamp + '\n' + content_hash
        signature = self.api_key.sign(canonical_message.encode('UTF-8'), ec.ECDSA(hashes.SHA256()))

        request_headers = {
            "X-Timestamp": timestamp,
            "X-Content-Hash": content_hash,
            "X-Key-Id": self.tenant + ':' + self.application + ':' + "default",
            "X-Key": self.api_public_key_bytes,
            "X-Authorization": standard_b64encode(signature),
            **extra_headers
        }

        # Rewind after hashing, so the connection sends the body from its start.
        body.seek(0)
        self.connection.request(method, path, body, request_headers)
        with self.connection.getresponse() as response:
            status = response.status
            raw = response.read()
        # Check the status before parsing: error responses are not necessarily JSON,
        # and the status code must not be hidden behind a parse error.
        if status != 200:
            raise RuntimeError("Status code " + str(status) + " doing " + method + " at " + url + ":\n" + self._error_message(raw))
        return json.loads(raw)

    def get_dev_region(self) -> str:
        """
        Look up the default region of the dev environment.

        :return: Name of the region.
        """
        return self.request('GET', '/zone/v1/environment/dev/default')['name']

    def get_endpoint(self, instance: str, region: str) -> str:
        """
        Find the URL of the 'test_app_container' cluster of the given dev deployment.

        :param instance: Name of the application instance.
        :param region: Name of the dev region where the instance is deployed.

        :return: URL of the first endpoint listed for the container cluster.

        :raises RuntimeError: if the deployment lists no endpoint for that cluster.
        """
        deployment_path = '/application/v4/tenant/{}/application/{}/instance/{}/environment/dev/region/{}' \
            .format(self.tenant, self.application, instance, region)
        endpoints = self.request('GET', deployment_path)['endpoints']
        container_url = [endpoint['url'] for endpoint in endpoints if endpoint['cluster'] == 'test_app_container']
        if not container_url:
            raise RuntimeError("No endpoints found for container 'test_app_container'")
        return container_url[0]

    def start_deployment(self, instance: str, job: str, application_package: ApplicationPackage) -> int:
        """
        Submit the given application package for deployment, and print the response message.

        The certificate of this client is added to the package as
        'application/security/clients.pem'.

        :param instance: Name of the application instance to deploy to.
        :param job: Name of the deployment job, e.g. 'dev-' followed by the region name.
        :param application_package: ApplicationPackage to be deployed.

        :return: The run id of the started deployment, as given in the response.
        """
        deploy_path = '/application/v4/tenant/{}/application/{}/instance/{}/deploy/{}' \
            .format(self.tenant, self.application, instance, job)
        client_certificate = self.data_certificate.public_bytes(serialization.Encoding.PEM)
        application_zip_bytes = application_package.to_application_zip(
            {'application/security/clients.pem': client_certificate}
        )
        response = self.request('POST', deploy_path, application_zip_bytes, {'Content-Type': 'application/zip'})
        print(response['message'])
        return response['run']

    def follow_deployment(self, instance: str, job: str, run: int):
        """
        Poll the given deployment run once per second, printing its log, until it is no longer active.

        :param instance: Name of the application instance being deployed.
        :param job: Name of the deployment job.
        :param run: Run id, as returned by start_deployment.

        :raises RuntimeError: if the run ends with any status other than 'success'.
        """
        last = -1
        while True:
            update = self.request('GET',
                                  '/application/v4/tenant/{}/application/{}/instance/{}/job/{}/run/{}?after={}' \
                                  .format(self.tenant, self.application, instance, job, run, last))

            for step, entries in update['log'].items():
                for entry in entries:
                    self.print_log_entry(step, entry)
            # Only log entries after this id are requested in the next round.
            last = update.get('lastId', last)

            if update['active']:
                sleep(1)
                continue

            status = update['status']
            if status == 'success':
                return
            raise RuntimeError(self._FAILURE_MESSAGES.get(status, "Unexpected status '" + status + "'"))

    @staticmethod
    def print_log_entry(step: str, entry: dict):
        """
        Print a single deployment log entry, as its type, time of day (UTC) and message.

        Entries of the 'copyVespaLogs' step are printed only when they are of type 'error'.

        :param step: Name of the deployment step the entry belongs to.
        :param entry: Log entry, with 'at' (milliseconds since epoch), 'type' and 'message'.
        """
        if step == 'copyVespaLogs' and entry['type'] != 'error':
            return
        timestamp = strftime('%H:%M:%S', gmtime(entry['at'] / 1e3))
        # Indent continuation lines so they line up under the first line of the message.
        message = entry['message'].replace('\n', '\n' + ' '*23)
        print('{:<7} [{}] {}'.format(entry['type'].upper(), timestamp, message))

    def deploy(self, instance: str, application_package: ApplicationPackage) -> Vespa:
        """
        Deploy the given application package as the given instance in the Vespa Cloud dev environment.

        :param instance: Name of this instance of the application, in the Vespa Cloud.
        :param application_package: ApplicationPackage to be deployed.

        :return: a Vespa connection instance.
        """
        region = self.get_dev_region()
        job = 'dev-' + region
        run = self.start_deployment(instance, job, application_package)
        self.follow_deployment(instance, job, run)
        endpoint_url = self.get_endpoint(instance, region)
        return Vespa(url=endpoint_url, cert=self.data_cert_file.name)

    def close(self):
        """Close the API connection and the temporary file holding the data plane key and certificate."""
        try:
            self.connection.close()
        finally:
            # Close the key file even if closing the connection fails.
            self.data_cert_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()