Sampling words by frequency

This post illustrates the sampling recipe mentioned in the lab on the skip-gram model.
Author

Marco Kuhlmann

Published

January 29, 2024

We want to sample words from a vocabulary with a probability that is proportional to their counts (absolute frequencies) in some given text. That is, if we have two words \(w_1\) and \(w_2\), where \(w_2\) appears \(k\) times as often as \(w_1\), then the expected number of times we sample \(w_2\) should be \(k\) times the expected number of times we sample \(w_1\).
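In symbols, writing \(c(w)\) for the count of a word \(w\), \(V\) for the vocabulary and \(N\) for the number of independent draws (notation introduced here for illustration), the target probabilities are

\[
P(w) = \frac{c(w)}{\sum_{v \in V} c(v)},
\]

and the expected number of times \(w\) is sampled is \(N \cdot P(w)\). If \(c(w_2) = k \cdot c(w_1)\), then

\[
N \cdot P(w_2) = N \cdot \frac{k \cdot c(w_1)}{\sum_{v \in V} c(v)} = k \cdot N \cdot P(w_1).
\]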

Sampling recipe

Imagine a line marked with numbers from 0 to the sum of all word frequencies, divided into consecutive intervals, one for each word in the vocabulary, where the length of each word's interval equals that word's frequency. To sample a word, we choose a point on that line uniformly at random and return the word whose interval contains this point. Because the probability of hitting an interval is proportional to its length, each word is sampled with a probability that is proportional to its frequency.
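The recipe fits in a few lines of plain Python. The sketch below is our own illustration, not the lab's code: the function name `sample_index` is hypothetical, and it assumes half-open intervals \([\text{start}, \text{end})\), which `bisect.bisect_right` matches exactly.

```python
import bisect
import itertools
import random


def sample_index(counts, rng=random):
    """Return index i with probability counts[i] / sum(counts).

    Word i covers the half-open interval [cumulative[i-1], cumulative[i])
    on a line from 0 to sum(counts).
    """
    # Right endpoints of the intervals, e.g. [1, 3] -> [1, 4]
    cumulative = list(itertools.accumulate(counts))
    # Uniform random point in [0, total)
    point = rng.random() * cumulative[-1]
    # First index i with point < cumulative[i], i.e. the interval containing point
    return bisect.bisect_right(cumulative, point)


# Word 1 is three times as frequent as word 0, so it should be drawn about 75% of the time
draws = [sample_index([1, 3]) for _ in range(10_000)]
print(draws.count(1) / len(draws))
```

Rebuilding the cumulative sums on every call is wasteful when many samples are needed; the example below computes them once and then looks up many points at a time.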

Example

We illustrate the sampling recipe with a concrete example.

import numpy as np
import torch

Here is a list of counts for words in a ten-word vocabulary:

counts = np.array([14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271])
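Before sampling, it helps to know which probabilities the recipe should reproduce. A small NumPy sketch (the names `total` and `probabilities` are ours):

```python
import numpy as np

counts = np.array([14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271])

# Each word's sampling probability is its share of the total count
total = counts.sum()
probabilities = counts / total

print(total)  # 45335
# Most frequent vs. least frequent word: about 11.4 times as likely
print(probabilities[0] / probabilities[-1])
```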

To implement the sampling recipe, we need the cumulative sums of these counts: the \(i\)th cumulative sum is the right end of the \(i\)th word's interval on the line. We can compute them with the function torch.cumsum().

cumulative_sums = torch.cumsum(torch.from_numpy(counts), dim=0)
cumulative_sums
tensor([14507, 19521, 24123, 28652, 32652, 35871, 38881, 41839, 44064, 45335])
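To make the intervals concrete, the following NumPy sketch recomputes the same cumulative sums with `np.cumsum` (NumPy's counterpart of `torch.cumsum`) and lists each word's interval, assuming half-open intervals. The name `starts` is ours.

```python
import numpy as np

counts = np.array([14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271])

# Right endpoints of the intervals
cumulative_sums = np.cumsum(counts)

# Left endpoints: 0 for the first word, then the previous right endpoint
starts = np.concatenate(([0], cumulative_sums[:-1]))
for i, (start, end) in enumerate(zip(starts, cumulative_sums)):
    # Word i covers [start, end); the width end - start equals counts[i]
    print(f"word {i}: [{start}, {end})")
```

The first line printed is `word 0: [0, 14507)`, and the last interval ends at the total count, 45335.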

To choose a random point on the line, we sample a random number uniformly from \([0, 1)\) with torch.rand() and multiply it by the sum of all counts, which is the last entry of the cumulative sums. Here we choose \(5\) such points.

random_points = torch.rand(5) * cumulative_sums[-1]
random_points
tensor([10790.7236, 36779.1250, 44573.6055,  3949.3364, 25750.5410])
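The points above change on every run. A sketch of the same step with a seeded NumPy generator, so that the points are reproducible; in PyTorch, calling `torch.manual_seed(0)` before `torch.rand(5)` plays the same role, although the two libraries produce different numbers for the same seed.

```python
import numpy as np

counts = np.array([14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271])
total = counts.sum()

# Seeded generator: the same seed gives the same points on every run
rng = np.random.default_rng(seed=0)
# rng.random returns floats in [0, 1), so the points lie in [0, total)
random_points = rng.random(5) * total
print(random_points)
```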

To return the word whose interval contains a chosen point, we use the function torch.searchsorted(). Given a sorted sequence and a tensor of values, this function returns, for each value, the index at which that value would have to be inserted into the sequence to keep it sorted. When the sorted sequence holds the cumulative sums, this insertion index is the index of the word whose interval contains the point (a point lying exactly on a boundary is assigned to the word whose interval ends there).
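Boundary handling can be checked with NumPy's `np.searchsorted`: its default `side="left"` behaves like the default `right=False` of `torch.searchsorted`, and `side="right"` like `right=True`. The sketch uses only the first three cumulative sums from above.

```python
import numpy as np

cumulative_sums = np.array([14507, 19521, 24123])
points = [0.0, 14506.5, 14507.0, 14507.5]

# side="left": index i with cumulative_sums[i-1] < v <= cumulative_sums[i]
left_idx = np.searchsorted(cumulative_sums, points)
print(left_idx)   # [0 0 0 1]

# side="right": index i with cumulative_sums[i-1] <= v < cumulative_sums[i],
# i.e. half-open intervals [start, end)
right_idx = np.searchsorted(cumulative_sums, points, side="right")
print(right_idx)  # [0 0 1 1]
```

The two variants differ only for points that coincide with a cumulative sum. With random floating-point points this practically never happens, so either choice gives essentially the same distribution.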

torch.searchsorted(cumulative_sums, random_points)
tensor([0, 6, 9, 0, 3])
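Putting the steps together, the following NumPy sketch (the function name `sample_words` is hypothetical) draws many samples and compares the empirical frequencies with the target probabilities \(c(w) / \sum_v c(v)\). If the recipe is implemented correctly, the largest deviation should be small.

```python
import numpy as np


def sample_words(counts, num_samples, rng):
    """Sample word indices with probability proportional to counts."""
    cumulative_sums = np.cumsum(counts)
    # Uniform random points in [0, total)
    random_points = rng.random(num_samples) * cumulative_sums[-1]
    # side="right" corresponds to half-open intervals [start, end)
    return np.searchsorted(cumulative_sums, random_points, side="right")


counts = np.array([14507, 5014, 4602, 4529, 4000, 3219, 3010, 2958, 2225, 1271])
rng = np.random.default_rng(seed=1)
samples = sample_words(counts, 100_000, rng)

# Fraction of draws per word vs. each word's share of the total count
empirical = np.bincount(samples, minlength=len(counts)) / len(samples)
expected = counts / counts.sum()
print(np.abs(empirical - expected).max())
```

Library routines that sample from the same distribution include `np.random.Generator.choice` with `p=expected` and `torch.multinomial` with `replacement=True`. The latter accepts unnormalized non-negative weights such as the raw counts.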

Good luck with the lab!