aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java40
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java89
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationInstanceAuthorizer.java (renamed from controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java)93
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java22
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/CreateSecurityContextFilter.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/PropagateSecurityContextFilter.java2
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java33
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/MockAuthorizer.java18
-rw-r--r--node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/noderepository/NodeRepositoryImplTest.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java1
14 files changed, 181 insertions, 155 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 83cc3ae418a..3e11eb72a30 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -250,46 +250,6 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
- @Test
- public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(rename(reduce(join(map(join(rename(reduce(join(join(join(constant(\"dnn_hidden1_mul_x\"), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))";
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "tensorflow('mnist/saved')",
- null,
- null,
- "input",
- application);
- search.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search);
-
- // At this point the expression is stored - copy application to another location which do not have a models dir
- Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
- try {
- storedApplicationDirectory.toFile().mkdirs();
- IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
- storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "tensorflow('mnist/saved')",
- null,
- null,
- "input",
- storedApplication);
- searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search);
- }
- finally {
- IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
- }
- }
-
- private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
- Value value = search.rankProfile("my_profile").getConstants().get(name);
- assertNotNull(value);
- assertEquals(type, value.type());
- }
-
/**
* Verifies that the constant with the given name exists, and - only if an expected size is given -
* that the content of the constant is available and has the expected size.
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index c18cfcfe1aa..b001db69768 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -64,7 +64,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
- assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
+ assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)");
}
@Test
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index 2105ea1d3d9..6fc65253da3 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -10,7 +10,6 @@ import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.Environment;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.TenantName;
-import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.LoggingRequestHandler;
@@ -19,6 +18,10 @@ import com.yahoo.log.LogLevel;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inspector;
import com.yahoo.slime.Slime;
+import com.yahoo.vespa.athenz.api.AthenzDomain;
+import com.yahoo.vespa.athenz.api.AthenzIdentity;
+import com.yahoo.vespa.athenz.api.AthenzUser;
+import com.yahoo.vespa.athenz.api.NToken;
import com.yahoo.vespa.config.SlimeUtils;
import com.yahoo.vespa.hosted.controller.AlreadyExistsException;
import com.yahoo.vespa.hosted.controller.Application;
@@ -36,7 +39,6 @@ import com.yahoo.vespa.hosted.controller.api.application.v4.model.ScrewdriverBui
import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.RefeedAction;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.RestartAction;
import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.ServiceInfo;
-import com.yahoo.vespa.athenz.api.AthenzDomain;
import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId;
import com.yahoo.vespa.hosted.controller.api.identifiers.GitBranch;
import com.yahoo.vespa.hosted.controller.api.identifiers.GitCommit;
@@ -48,10 +50,13 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId;
import com.yahoo.vespa.hosted.controller.api.identifiers.TenantId;
import com.yahoo.vespa.hosted.controller.api.identifiers.UserGroup;
import com.yahoo.vespa.hosted.controller.api.identifiers.UserId;
+import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory;
+import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsException;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerException;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.Log;
import com.yahoo.vespa.hosted.controller.api.integration.organization.User;
import com.yahoo.vespa.hosted.controller.api.integration.routing.RotationStatus;
+import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId;
import com.yahoo.vespa.hosted.controller.application.ApplicationPackage;
import com.yahoo.vespa.hosted.controller.application.ApplicationVersion;
import com.yahoo.vespa.hosted.controller.application.Change;
@@ -62,12 +67,6 @@ import com.yahoo.vespa.hosted.controller.application.DeploymentCost;
import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics;
import com.yahoo.vespa.hosted.controller.application.JobStatus;
import com.yahoo.vespa.hosted.controller.application.SourceRevision;
-import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory;
-import com.yahoo.vespa.athenz.api.AthenzIdentity;
-import com.yahoo.vespa.athenz.api.AthenzPrincipal;
-import com.yahoo.vespa.athenz.api.AthenzUser;
-import com.yahoo.vespa.athenz.api.NToken;
-import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsException;
import com.yahoo.vespa.hosted.controller.restapi.ErrorResponse;
import com.yahoo.vespa.hosted.controller.restapi.MessageResponse;
import com.yahoo.vespa.hosted.controller.restapi.Path;
@@ -85,7 +84,6 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
-import java.security.Principal;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
@@ -107,6 +105,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
private final Controller controller;
private final Authorizer authorizer;
private final AthenzClientFactory athenzClientFactory;
+ private final ApplicationInstanceAuthorizer applicationInstanceAuthorizer;
@Inject
public ApplicationApiHandler(LoggingRequestHandler.Context parentCtx,
@@ -116,6 +115,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
this.controller = controller;
this.authorizer = authorizer;
this.athenzClientFactory = athenzClientFactory;
+ this.applicationInstanceAuthorizer = new ApplicationInstanceAuthorizer(controller.zoneRegistry(), athenzClientFactory);
}
@Override
@@ -192,13 +192,13 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
Path path = new Path(request.getUri().getPath());
if (path.matches("/application/v4/tenant/{tenant}")) return createTenant(path.get("tenant"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}")) return createApplication(path.get("tenant"), path.get("application"), request);
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/promote")) return promoteApplication(path.get("tenant"), path.get("application"));
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/promote")) return promoteApplication(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying")) return deploy(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}")) return deploy(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/deploy")) return deploy(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); // legacy synonym of the above
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/restart")) return restart(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request);
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/log")) return log(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"));
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/promote")) return promoteApplicationDeployment(path.get("tenant"), path.get("application"), path.get("environment"), path.get("region"));
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/log")) return log(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/promote")) return promoteApplicationDeployment(path.get("tenant"), path.get("application"), path.get("environment"), path.get("region"), path.get("instance"), request);
return ErrorResponse.notFoundError("Nothing at " + path);
}
@@ -207,7 +207,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
if (path.matches("/application/v4/tenant/{tenant}")) return deleteTenant(path.get("tenant"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}")) return deleteApplication(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying")) return cancelDeploy(path.get("tenant"), path.get("application"));
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}")) return deactivate(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"));
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}")) return deactivate(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/global-rotation/override"))
return setGlobalRotationOverride(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), true, request);
return ErrorResponse.notFoundError("Nothing at " + path);
@@ -238,7 +238,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
private HttpResponse authenticatedUser(HttpRequest request) {
String userIdString = request.getProperty("userOverride");
if (userIdString == null)
- userIdString = userFrom(request)
+ userIdString = authorizer.getUserId(request)
.map(UserId::id)
.orElseThrow(() -> new ForbiddenException("You must be authenticated or specify userOverride"));
UserId userId = new UserId(userIdString);
@@ -594,8 +594,8 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
}
private HttpResponse createUser(HttpRequest request) {
- Optional<UserId> user = userFrom(request);
- if ( ! user.isPresent() ) throw new ForbiddenException("Not authenticated.");
+ Optional<UserId> user = authorizer.getUserId(request);
+ if ( ! user.isPresent() ) throw new ForbiddenException("Not authenticated or not an user.");
try {
controller.tenants().createUserTenant(user.get().id());
@@ -700,6 +700,8 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
}
/** Trigger deployment of the last built application package, on a given version */
+ // TODO Add authorization
+ // TODO Consider move to API for maintenance related operations
private HttpResponse deploy(String tenantName, String applicationName, HttpRequest request) {
Version version = decideDeployVersion(request);
if ( ! systemHasVersion(version))
@@ -719,6 +721,8 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
}
/** Cancel any ongoing change for given application */
+ // TODO Add authorization
+ // TODO Consider move to API for maintenance related operations
private HttpResponse cancelDeploy(String tenantName, String applicationName) {
ApplicationId id = ApplicationId.from(tenantName, applicationName, "default");
Application application = controller.applications().require(id);
@@ -736,8 +740,14 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
private HttpResponse restart(String tenantName, String applicationName, String instanceName, String environment, String region, HttpRequest request) {
DeploymentId deploymentId = new DeploymentId(ApplicationId.from(tenantName, applicationName, instanceName),
ZoneId.from(environment, region));
+
// TODO: Propagate all filters
Optional<Hostname> hostname = Optional.ofNullable(request.getProperty("hostname")).map(Hostname::new);
+
+ applicationInstanceAuthorizer.throwIfUnauthorized(authorizer.getPrincipal(request),
+ Environment.from(environment),
+ getTenantOrThrow(tenantName),
+ deploymentId.applicationId().application());
controller.applications().restart(deploymentId, hostname);
// TODO: Change to return JSON
@@ -753,10 +763,15 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
* the application is working. It is called for all production zones, also those in which the application is not present,
* and possibly before it is present, so failures are normal and expected.
*/
- private HttpResponse log(String tenantName, String applicationName, String instanceName, String environment, String region) {
+ private HttpResponse log(String tenantName, String applicationName, String instanceName, String environment, String region, HttpRequest request) {
try {
DeploymentId deploymentId = new DeploymentId(ApplicationId.from(tenantName, applicationName, instanceName),
ZoneId.from(environment, region));
+
+ applicationInstanceAuthorizer.throwIfUnauthorized(authorizer.getPrincipal(request),
+ Environment.from(environment),
+ getTenantOrThrow(tenantName),
+ deploymentId.applicationId().application());
return new JacksonJsonResponse(controller.grabLog(deploymentId));
}
catch (RuntimeException e) {
@@ -778,10 +793,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
Optional<ApplicationPackage> applicationPackage = Optional.ofNullable(dataParts.get("applicationZip"))
.map(ApplicationPackage::new);
- DeployAuthorizer deployAuthorizer = new DeployAuthorizer(controller.zoneRegistry(), athenzClientFactory);
- Tenant tenant = controller.tenants().tenant(new TenantId(tenantName)).orElseThrow(() -> new NotExistsException(new TenantId(tenantName)));
- Principal principal = authorizer.getPrincipal(request);
- deployAuthorizer.throwIfUnauthorizedForDeploy(principal, Environment.from(environment), tenant, applicationId, applicationPackage);
+ applicationInstanceAuthorizer.throwIfUnauthorizedForDeploy(authorizer.getPrincipal(request),
+ Environment.from(environment),
+ getTenantOrThrow(tenantName),
+ ApplicationName.from(applicationName),
+ applicationPackage);
// TODO: get rid of the json object
DeployOptions deployOptionsJsonClass = new DeployOptions(screwdriverBuildJobFromSlime(deployOptions.field("screwdriverBuildJob")),
@@ -814,11 +830,17 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
return new EmptyJsonResponse(); // TODO: Replicates current behavior but should return a message response instead
}
- private HttpResponse deactivate(String tenantName, String applicationName, String instanceName, String environment, String region) {
+ private HttpResponse deactivate(String tenantName, String applicationName, String instanceName, String environment, String region, HttpRequest request) {
Application application = controller.applications().require(ApplicationId.from(tenantName, applicationName, instanceName));
ZoneId zone = ZoneId.from(environment, region);
Deployment deployment = application.deployments().get(zone);
+
+ applicationInstanceAuthorizer.throwIfUnauthorized(authorizer.getPrincipal(request),
+ Environment.from(environment),
+ getTenantOrThrow(tenantName),
+ ApplicationName.from(applicationName));
+
if (deployment == null) {
// Attempt to deactivate application even if the deployment is not known by the controller
controller.applications().deactivate(application, zone);
@@ -837,8 +859,12 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
/**
* Promote application Chef environments. To be used by component jobs only
*/
- private HttpResponse promoteApplication(String tenantName, String applicationName) {
+ private HttpResponse promoteApplication(String tenantName, String applicationName, HttpRequest request) {
try{
+ applicationInstanceAuthorizer.throwIfUnauthorized(authorizer.getPrincipal(request),
+ getTenantOrThrow(tenantName),
+ ApplicationName.from(applicationName));
+
ApplicationChefEnvironment chefEnvironment = new ApplicationChefEnvironment(controller.system());
String sourceEnvironment = chefEnvironment.systemChefEnvironment();
String targetEnvironment = chefEnvironment.applicationSourceEnvironment(TenantName.from(tenantName), ApplicationName.from(applicationName));
@@ -853,8 +879,13 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
/**
* Promote application Chef environments for jobs that deploy applications
*/
- private HttpResponse promoteApplicationDeployment(String tenantName, String applicationName, String environmentName, String regionName) {
+ private HttpResponse promoteApplicationDeployment(String tenantName, String applicationName, String environmentName, String regionName, String instanceName, HttpRequest request) {
try {
+ applicationInstanceAuthorizer.throwIfUnauthorized(authorizer.getPrincipal(request),
+ Environment.from(environmentName),
+ getTenantOrThrow(tenantName),
+ ApplicationName.from(applicationName));
+
ApplicationChefEnvironment chefEnvironment = new ApplicationChefEnvironment(controller.system());
String sourceEnvironment = chefEnvironment.applicationSourceEnvironment(TenantName.from(tenantName), ApplicationName.from(applicationName));
String targetEnvironment = chefEnvironment.applicationTargetEnvironment(TenantName.from(tenantName), ApplicationName.from(applicationName), Environment.from(environmentName), RegionName.from(regionName));
@@ -866,13 +897,9 @@ public class ApplicationApiHandler extends LoggingRequestHandler {
}
}
- private Optional<UserId> userFrom(HttpRequest request) {
- return authorizer.getPrincipalIfAny(request)
- .map(AthenzPrincipal::getIdentity)
- .filter(AthenzUser.class::isInstance)
- .map(AthenzUser.class::cast)
- .map(AthenzUser::getName)
- .map(UserId::new);
+ private Tenant getTenantOrThrow(String tenantName) {
+ return controller.tenants().tenant(new TenantId(tenantName))
+ .orElseThrow(() -> new NotExistsException(new TenantId(tenantName)));
}
private void toSlime(Cursor object, Tenant tenant, HttpRequest request, boolean listApplications) {
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationInstanceAuthorizer.java
index af519439600..283d700c2bd 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationInstanceAuthorizer.java
@@ -2,7 +2,7 @@
package com.yahoo.vespa.hosted.controller.restapi.application;
import com.yahoo.config.application.api.DeploymentSpec;
-import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.Environment;
import com.yahoo.vespa.athenz.api.AthenzDomain;
import com.yahoo.vespa.athenz.api.AthenzPrincipal;
@@ -16,7 +16,6 @@ import com.yahoo.vespa.hosted.controller.application.ApplicationPackage;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.NotAuthorizedException;
-import java.security.Principal;
import java.util.Objects;
import java.util.Optional;
import java.util.logging.Logger;
@@ -25,56 +24,27 @@ import static com.yahoo.vespa.hosted.controller.api.integration.athenz.HostedAth
import static com.yahoo.vespa.hosted.controller.restapi.application.Authorizer.environmentRequiresAuthorization;
/**
+ * Validates that principal is allowed to perform a mutating operation on an application instance.
+ *
* @author bjorncs
* @author gjoranv
*/
-public class DeployAuthorizer {
+public class ApplicationInstanceAuthorizer {
- private static final Logger log = Logger.getLogger(DeployAuthorizer.class.getName());
+ private static final Logger log = Logger.getLogger(ApplicationInstanceAuthorizer.class.getName());
private final ZoneRegistry zoneRegistry;
private final AthenzClientFactory athenzClientFactory;
- public DeployAuthorizer(ZoneRegistry zoneRegistry, AthenzClientFactory athenzClientFactory) {
+ public ApplicationInstanceAuthorizer(ZoneRegistry zoneRegistry, AthenzClientFactory athenzClientFactory) {
this.zoneRegistry = zoneRegistry;
this.athenzClientFactory = athenzClientFactory;
}
- public void throwIfUnauthorizedForDeploy(Principal principal,
- Environment environment,
- Tenant tenant,
- ApplicationId applicationId,
- Optional<ApplicationPackage> applicationPackage) {
- // Validate that domain in identity configuration (deployment.xml) is same as tenant domain
- applicationPackage.map(ApplicationPackage::deploymentSpec).flatMap(DeploymentSpec::athenzDomain)
- .ifPresent(identityDomain -> {
- AthenzDomain tenantDomain = tenant.getAthensDomain().orElseThrow(() -> new IllegalArgumentException("Identity provider only available to Athenz onboarded tenants"));
- if (! Objects.equals(tenantDomain.getName(), identityDomain.value())) {
- throw new ForbiddenException(
- String.format(
- "Athenz domain in deployment.xml: [%s] must match tenant domain: [%s]",
- identityDomain.value(),
- tenantDomain.getName()
- ));
- }
- });
-
- if (!environmentRequiresAuthorization(environment)) {
- return;
- }
-
- if (principal == null) {
- throw loggedUnauthorizedException("Principal not authenticated!");
- }
-
- if (!(principal instanceof AthenzPrincipal)) {
- throw loggedUnauthorizedException(
- "Principal '%s' of type '%s' is not an Athenz principal, which is required for production deployments.",
- principal.getName(), principal.getClass().getSimpleName());
- }
-
- AthenzPrincipal athenzPrincipal = (AthenzPrincipal) principal;
- AthenzDomain principalDomain = athenzPrincipal.getDomain();
+ public void throwIfUnauthorized(AthenzPrincipal principal,
+ Tenant tenant,
+ ApplicationName application) {
+ AthenzDomain principalDomain = principal.getDomain();
if (!principalDomain.equals(SCREWDRIVER_DOMAIN)) {
throw loggedForbiddenException(
@@ -89,16 +59,47 @@ public class DeployAuthorizer {
// NOTE: no fine-grained deploy authorization for non-Athenz tenants
if (tenant.isAthensTenant()) {
AthenzDomain tenantDomain = tenant.getAthensDomain().get();
- if (!hasDeployAccessToAthenzApplication(athenzPrincipal, tenantDomain, applicationId)) {
+ if (!hasDeployAccessToAthenzApplication(principal, tenantDomain, application)) {
throw loggedForbiddenException(
"Screwdriver principal '%1$s' does not have deploy access to '%2$s'. " +
- "Either the application has not been created at " + zoneRegistry.getDashboardUri() + " or " +
- "'%1$s' is not added to the application's deployer role in Athenz domain '%3$s'.",
- athenzPrincipal.getIdentity().getFullName(), applicationId, tenantDomain.getName());
+ "Either the application has not been created at " + zoneRegistry.getDashboardUri() + " or " +
+ "'%1$s' is not added to the application's deployer role in Athenz domain '%3$s'.",
+ principal.getIdentity().getFullName(), application.value(), tenantDomain.getName());
}
}
}
+ public void throwIfUnauthorized(AthenzPrincipal principal,
+ Environment environment,
+ Tenant tenant,
+ ApplicationName application) {
+ if (!environmentRequiresAuthorization(environment)) {
+ return;
+ }
+ throwIfUnauthorized(principal, tenant, application);
+ }
+
+ public void throwIfUnauthorizedForDeploy(AthenzPrincipal principal,
+ Environment environment,
+ Tenant tenant,
+ ApplicationName application,
+ Optional<ApplicationPackage> applicationPackage) {
+ // Validate that domain in identity configuration (deployment.xml) is same as tenant domain
+ applicationPackage.map(ApplicationPackage::deploymentSpec).flatMap(DeploymentSpec::athenzDomain)
+ .ifPresent(identityDomain -> {
+ AthenzDomain tenantDomain = tenant.getAthensDomain().orElseThrow(() -> new IllegalArgumentException("Identity provider only available to Athenz onboarded tenants"));
+ if (! Objects.equals(tenantDomain.getName(), identityDomain.value())) {
+ throw new ForbiddenException(
+ String.format(
+ "Athenz domain in deployment.xml: [%s] must match tenant domain: [%s]",
+ identityDomain.value(),
+ tenantDomain.getName()
+ ));
+ }
+ });
+ throwIfUnauthorized(principal, environment, tenant, application);
+ }
+
private static ForbiddenException loggedForbiddenException(String message, Object... args) {
String formattedMessage = String.format(message, args);
log.info(formattedMessage);
@@ -111,14 +112,14 @@ public class DeployAuthorizer {
return new NotAuthorizedException(formattedMessage);
}
- private boolean hasDeployAccessToAthenzApplication(AthenzPrincipal principal, AthenzDomain domain, ApplicationId applicationId) {
+ private boolean hasDeployAccessToAthenzApplication(AthenzPrincipal principal, AthenzDomain domain, ApplicationName application) {
try {
return athenzClientFactory.createZmsClientWithServicePrincipal()
.hasApplicationAccess(
principal.getIdentity(),
ApplicationAction.deploy,
domain,
- new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(applicationId.application().value()));
+ new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(application.value()));
} catch (ZmsException e) {
throw loggedForbiddenException(
"Failed to authorize deployment through Athenz. If this problem persists, " +
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java
index 06d078e8a36..9d45b9a6e09 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java
@@ -66,18 +66,22 @@ public class Authorizer {
/** Returns the principal or throws forbidden */ // TODO: Avoid REST exceptions
public AthenzPrincipal getPrincipal(HttpRequest request) {
- return getPrincipalIfAny(request).orElseThrow(() -> Authorizer.loggedForbiddenException("User is not authenticated"));
+ return Optional.ofNullable(request.getJDiscRequest().getUserPrincipal())
+ .map(AthenzPrincipal.class::cast)
+ .orElseThrow(() -> loggedForbiddenException("User is not authenticated"));
}
- /** Returns the principal if there is any */
- public Optional<AthenzPrincipal> getPrincipalIfAny(HttpRequest request) {
- return securityContextOf(request)
- .map(SecurityContext::getUserPrincipal)
- .map(AthenzPrincipal.class::cast);
+ public Optional<NToken> getNToken(HttpRequest request) {
+ return getPrincipal(request).getNToken();
}
- public Optional<NToken> getNToken(HttpRequest request) {
- return getPrincipalIfAny(request).flatMap(AthenzPrincipal::getNToken);
+ public Optional<UserId> getUserId(HttpRequest request) {
+ return Optional.of(getPrincipal(request))
+ .map(AthenzPrincipal::getIdentity)
+ .filter(AthenzUser.class::isInstance)
+ .map(AthenzUser.class::cast)
+ .map(AthenzUser::getName)
+ .map(UserId::new);
}
public boolean isSuperUser(HttpRequest request) {
@@ -147,6 +151,8 @@ public class Authorizer {
return securityContext.get().isUserInRole(Authorizer.VESPA_HOSTED_ADMIN_ROLE);
}
+ @Deprecated
+ // TODO: Remove once Bouncer filter is no longer needed
protected Optional<SecurityContext> securityContextOf(HttpRequest request) {
return Optional.ofNullable((SecurityContext)request.getJDiscRequest().context().get(ContextAttributes.SECURITY_CONTEXT_ATTRIBUTE));
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/CreateSecurityContextFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/CreateSecurityContextFilter.java
index 6073307bafa..5fc15f4baa6 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/CreateSecurityContextFilter.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/CreateSecurityContextFilter.java
@@ -20,6 +20,8 @@ import java.security.Principal;
@After("BouncerFilter")
@Provides("SecurityContext")
@SuppressWarnings("unused") // Injected
+@Deprecated
+// TODO Remove once Bouncer filter is gone
public class CreateSecurityContextFilter implements SecurityRequestFilter {
@Override
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/PropagateSecurityContextFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/PropagateSecurityContextFilter.java
index 17c86e89362..23f94f2fc21 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/PropagateSecurityContextFilter.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/securitycontext/PropagateSecurityContextFilter.java
@@ -18,6 +18,8 @@ import java.io.IOException;
*/
@PreMatching
@Provider
+// TODO Remove once Bouncer filter is gone
+@Deprecated
public class PropagateSecurityContextFilter implements ContainerRequestFilter {
@Override
public void filter(ContainerRequestContext requestContext) throws IOException {
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
index 61a4a883904..e6fe7531fdc 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
@@ -197,7 +197,8 @@ public class ApplicationApiTest extends ControllerContainerTest {
.data(createApplicationDeployData(applicationPackage, Optional.of(screwdriverProjectId)))
.screwdriverIdentity(SCREWDRIVER_ID),
new File("deploy-result.json"));
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/test/region/us-east-1/instance/default", DELETE),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/test/region/us-east-1/instance/default", DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated tenant/tenant1/application/application1/environment/test/region/us-east-1/instance/default");
controllerTester.notifyJobCompletion(id, screwdriverProjectId, true, DeploymentJobs.JobType.systemTest); // Called through the separate screwdriver/v1 API
@@ -206,7 +207,8 @@ public class ApplicationApiTest extends ControllerContainerTest {
.data(createApplicationDeployData(applicationPackage, Optional.of(screwdriverProjectId)))
.screwdriverIdentity(SCREWDRIVER_ID),
new File("deploy-result.json"));
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/staging/region/us-east-3/instance/default", DELETE),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/staging/region/us-east-3/instance/default", DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated tenant/tenant1/application/application1/environment/staging/region/us-east-3/instance/default");
controllerTester.notifyJobCompletion(id, screwdriverProjectId, true, DeploymentJobs.JobType.stagingTest);
@@ -252,10 +254,12 @@ public class ApplicationApiTest extends ControllerContainerTest {
// POST a 'restart application' command
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default/restart", POST),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default/restart", POST)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Requested restart of tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default");
// POST a 'restart application' command with a host filter (other filters not supported yet)
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default/restart?hostname=host1", POST),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default/restart?hostname=host1", POST)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Requested restart of tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default");
// POST a 'log' command
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default/log", POST),
@@ -275,14 +279,17 @@ public class ApplicationApiTest extends ControllerContainerTest {
new File("delete-with-active-deployments.json"), 400);
// DELETE (deactivate) a deployment - dev
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/dev/region/us-west-1/instance/default", DELETE),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/dev/region/us-west-1/instance/default", DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated tenant/tenant1/application/application1/environment/dev/region/us-west-1/instance/default");
// DELETE (deactivate) a deployment - prod
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default", DELETE),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default", DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default");
// DELETE (deactivate) a deployment is idempotent
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default", DELETE),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default", DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated tenant/tenant1/application/application1/environment/prod/region/corp-us-east-1/instance/default");
// PUT (create) the authenticated user
@@ -315,9 +322,11 @@ public class ApplicationApiTest extends ControllerContainerTest {
.data("{\"reason\":\"because i can\"}"),
new File("global-rotation-delete.json"));
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/promote", POST),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/promote", POST)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"{\"message\":\"Successfully copied environment hosted-verified-prod to hosted-instance_tenant1_application1_placeholder_component_default\"}");
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-west-1/instance/default/promote", POST),
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-west-1/instance/default/promote", POST)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"{\"message\":\"Successfully copied environment hosted-instance_tenant1_application1_placeholder_component_default to hosted-instance_tenant1_application1_us-west-1_prod_default\"}");
// DELETE an application
@@ -816,7 +825,8 @@ public class ApplicationApiTest extends ControllerContainerTest {
.data(deployData)
.screwdriverIdentity(SCREWDRIVER_ID),
new File("deploy-result.json"));
- tester.assertResponse(request(testPath, DELETE),
+ tester.assertResponse(request(testPath, DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated " + testPath.replaceFirst("/application/v4/", ""));
controllerTester.notifyJobCompletion(application, projectId, true, DeploymentJobs.JobType.systemTest);
@@ -827,7 +837,8 @@ public class ApplicationApiTest extends ControllerContainerTest {
.data(deployData)
.screwdriverIdentity(SCREWDRIVER_ID),
new File("deploy-result.json"));
- tester.assertResponse(request(stagingPath, DELETE),
+ tester.assertResponse(request(stagingPath, DELETE)
+ .screwdriverIdentity(SCREWDRIVER_ID),
"Deactivated " + stagingPath.replaceFirst("/application/v4/", ""));
controllerTester.notifyJobCompletion(application, projectId, true, DeploymentJobs.JobType.stagingTest);
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/MockAuthorizer.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/MockAuthorizer.java
index d0f5f4dbdb9..f2fc4b12096 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/MockAuthorizer.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/MockAuthorizer.java
@@ -11,6 +11,7 @@ import com.yahoo.vespa.hosted.controller.TestIdentities;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory;
import com.yahoo.vespa.hosted.controller.api.integration.entity.EntityService;
+import javax.ws.rs.ForbiddenException;
import javax.ws.rs.core.SecurityContext;
import java.security.Principal;
import java.util.Optional;
@@ -31,14 +32,15 @@ public class MockAuthorizer extends Authorizer {
/** Returns a principal given by the request parameters 'domain' and 'user' */
@Override
- public Optional<AthenzPrincipal> getPrincipalIfAny(HttpRequest request) {
+ public AthenzPrincipal getPrincipal(HttpRequest request) {
String domain = request.getHeader("Athenz-Identity-Domain");
String name = request.getHeader("Athenz-Identity-Name");
- if (domain == null || name == null) return Optional.empty();
- return Optional.of(
- new AthenzPrincipal(
- AthenzIdentities.from(new AthenzDomain(domain), name),
- new NToken("dummy")));
+ if (domain == null || name == null) {
+ throw new ForbiddenException("User is not authenticated");
+ }
+ return new AthenzPrincipal(
+ AthenzIdentities.from(new AthenzDomain(domain), name),
+ new NToken("dummy"));
}
/** Returns the hardcoded NToken of {@link TestIdentities#userId} */
@@ -50,9 +52,9 @@ public class MockAuthorizer extends Authorizer {
@Override
protected Optional<SecurityContext> securityContextOf(HttpRequest request) {
- return getPrincipalIfAny(request).map(MockSecurityContext::new);
+ return Optional.of(new MockSecurityContext(getPrincipal(request)));
}
-
+
private static final class MockSecurityContext implements SecurityContext {
private final Principal principal;
diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/noderepository/NodeRepositoryImplTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/noderepository/NodeRepositoryImplTest.java
index 85d92dbee25..949b4ccdf78 100644
--- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/noderepository/NodeRepositoryImplTest.java
+++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/noderepository/NodeRepositoryImplTest.java
@@ -55,10 +55,22 @@ public class NodeRepositoryImplTest {
*/
@Before
public void startContainer() throws Exception {
- final int port = findRandomOpenPort();
- requestExecutor = ConfigServerHttpRequestExecutor.create(
- Collections.singleton(URI.create("http://127.0.0.1:" + port)), Optional.empty(), Optional.empty(), Optional.empty());
- container = JDisc.fromServicesXml(ContainerConfig.servicesXmlV2(port), Networking.enable);
+ Exception lastException = null;
+
+ // This tries to bind a random open port for the node-repo mock, which is a race condition, so try
+ // a few times before giving up
+ for (int i = 0; i < 3; i++) {
+ try {
+ final int port = findRandomOpenPort();
+ container = JDisc.fromServicesXml(ContainerConfig.servicesXmlV2(port), Networking.enable);
+ requestExecutor = ConfigServerHttpRequestExecutor.create(
+ Collections.singleton(URI.create("http://127.0.0.1:" + port)), Optional.empty(), Optional.empty(), Optional.empty());
+ return;
+ } catch (RuntimeException e) {
+ lastException = e;
+ }
+ }
+ throw new RuntimeException("Failed to bind a port in three attempts, giving up", lastException);
}
private void waitForJdiscContainerToServe() throws InterruptedException {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 59d2d95b879..e5a9e6a5ef1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -33,7 +33,7 @@ public abstract class Value {
/** Returns this as a tensor value */
public abstract Tensor asTensor();
- /** A utility method for wrapping a sdouble in a rank 0 tensor */
+ /** A utility method for wrapping a double in a rank 0 tensor */
protected Tensor doubleAsTensor(double value) {
return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 55782c36d18..ef82045e771 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -156,6 +156,12 @@ class OperationMapper {
private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
+ if (value.type().rank() == 0) {
+ TypedTensorFunction output = new TypedTensorFunction(value.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(
+ new ConstantNode(new DoubleValue(value.asDouble()))));
+ return Optional.of(output);
+ }
return createConstant(params, value);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index e4c381972e9..ec6af4bb413 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -117,11 +117,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public Tensor evaluate(EvaluationContext context) {
- Value result = expression.evaluate((Context)context);
- if ( ! ( result instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
- "but this returns " + result + ", not a tensor");
- return result.asTensor();
+ return expression.evaluate((Context)context).asTensor();
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6c7643b37b3..e9030cf5852 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -294,6 +294,7 @@ public class EvaluationTestCase {
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }",
"tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }");
+ tester.assertEvaluates("{ {x:0}:0.1 }", "join(tensor0, 0.1, f(x,y) (x*y))", "{ {x:0}:1 }");
// TODO
// argmax