Skip to content

Commit

Permalink
#44 Adding PPTSS sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzonthemtn committed Nov 21, 2024
1 parent 87e168f commit 0fcfeb7
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeQuerySampler;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.rest.BaseRestHandler;
Expand Down Expand Up @@ -95,7 +97,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final String name = request.param("name");
final String description = request.param("description");
final String sampling = request.param("sampling", "pptss");
final int maxQueries = Integer.parseInt(request.param("max_queries", "1000"));
final int querySetSize = Integer.parseInt(request.param("query_set_size", "1000"));

// Create a query set by finding all the unique user_query terms.
if (StringUtils.equalsIgnoreCase(sampling, "none")) {
Expand All @@ -109,22 +111,22 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
searchSourceBuilder.from(0);
searchSourceBuilder.size(maxQueries);
searchSourceBuilder.size(querySetSize);

final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME);
searchRequest.source(searchSourceBuilder);

final SearchResponse searchResponse = client.search(searchRequest).get();

LOGGER.info("Found {} user queries from the ubi_queries index.", searchResponse.getHits().getTotalHits().toString());
// LOGGER.info("Found {} user queries from the ubi_queries index.", searchResponse.getHits().getTotalHits().toString());

final Set<String> queries = new HashSet<>();
for(final SearchHit hit : searchResponse.getHits().getHits()) {
final Map<String, Object> fields = hit.getSourceAsMap();
queries.add(fields.get("user_query").toString());
}

LOGGER.info("Found {} user queries from the ubi_queries index.", queries.size());
// LOGGER.info("Found {} user queries from the ubi_queries index.", queries.size());

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
Expand All @@ -138,8 +140,12 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
// Create a query set by using PPTSS sampling.
} else if (StringUtils.equalsIgnoreCase(sampling, "pptss")) {

// TODO: Use the PPTSS sampling method - https://opensourceconnections.com/blog/2022/10/13/how-to-succeed-with-explicit-relevance-evaluation-using-probability-proportional-to-size-sampling/
final Collection<String> queries = List.of("computer", "desk", "table", "battery");
final ProbabilityProportionalToSizeParameters parameters = new ProbabilityProportionalToSizeParameters(querySetSize);
final ProbabilityProportionalToSizeQuerySampler sampler = new ProbabilityProportionalToSizeQuerySampler(parameters);

// TODO: Get all queries from the ubi_queries index.

final Collection<String> queries = sampler.sample();

try {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

public class ProbabilityProportionalToSizeParameters {

private final int querySetSize;

public ProbabilityProportionalToSizeParameters(int querySetSize) {
this.querySetSize = querySetSize;
}

public int getQuerySetSize() {
return querySetSize;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

import org.opensearch.eval.judgments.model.ubi.query.UbiQuery;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

/**
* An implementation of {@link QuerySampler} that uses PPTSS sampling.
* See https://opensourceconnections.com/blog/2022/10/13/how-to-succeed-with-explicit-relevance-evaluation-using-probability-proportional-to-size-sampling/
* for more information on PPTSS.
*/
public class ProbabilityProportionalToSizeQuerySampler implements QuerySampler {

private final ProbabilityProportionalToSizeParameters parameters;

/**
* Creates a new PPTSS sampler.
* @param parameters The {@link ProbabilityProportionalToSizeParameters parameters} for the sampling.
*/
public ProbabilityProportionalToSizeQuerySampler(final ProbabilityProportionalToSizeParameters parameters) {
this.parameters = parameters;
}

@Override
public String getName() {
return "pptss";
}

@Override
public Collection<String> sample(final Collection<String> userQueries) {

final Map<String, Long> weights = new HashMap<>();

// Increment the weight for each user query.
for(final String userQuery : userQueries) {
weights.merge(userQuery, 1L, Long::sum);
}

// The total number of queries will be used to normalize the weights.
final long countOfQueries = userQueries.size();

// Calculate the normalized weights by dividing by the total number of queries.
final Map<String, Double> normalizedWeights = new HashMap<>();
for(final String userQuery : weights.keySet()) {
normalizedWeights.put(userQuery, weights.get(userQuery) / (double) countOfQueries);
}

// Ensure all normalized weights sum to 1.
final double sumOfNormalizedWeights = normalizedWeights.values().stream().reduce(0.0, Double::sum);
if(sumOfNormalizedWeights != 1.0) {
throw new RuntimeException("Summed normalized weights do not equal 1.0");
}

final Collection<String> querySet = new ArrayList<>();
final Set<Double> randomNumbers = new HashSet<>();

// Generate a random number between 0 and 1 for the size of the query set.
for(int count = 0; count < parameters.getQuerySetSize(); count++) {

// Make a random number not yet used.
double random;
do {
random = Math.random();
} while (randomNumbers.contains(random));
randomNumbers.add(random);

// Find the weight closest to the random weight.
double finalRandom = random;
double nearestWeight = normalizedWeights.values().stream()
.min(Comparator.comparingDouble(i -> Math.abs(i - finalRandom)))
.orElseThrow(() -> new NoSuchElementException("No value present"));

// Find the query having the weight closest to this random number.
for(Map.Entry<String, Double> entry : normalizedWeights.entrySet()) {
if(entry.getValue() == nearestWeight) {
querySet.add(entry.getKey());
break;
}
}

}

return querySet;

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

import org.opensearch.eval.judgments.model.ubi.query.UbiQuery;

import java.util.Collection;

/**
* An interface for sampling UBI queries.
*/
public interface QuerySampler {

/**
* Gets the name of the sampler.
* @return The name of the sampler.
*/
String getName();

/**
* Samples the queries.
* @param userQueries A collection of user queries from UBI queries.
* @return A collection of sampled user queries.
*/
Collection<String> sample(Collection<String> userQueries);

}

0 comments on commit 0fcfeb7

Please sign in to comment.