Podemos fazer isso de duas maneiras simples . O primeiro é fácil de codificar, fácil de entender e razoavelmente rápido. O segundo é um pouco mais complicado, mas muito mais eficiente para esse tamanho de problema do que o primeiro método ou outras abordagens mencionadas aqui.
Método 1 : Rápido e sujo.
Para obter uma única observação da distribuição de probabilidade de cada linha, podemos simplesmente fazer o seguinte.
# Q is the cumulative distribution of each row.
Q <- t(apply(P,1,cumsum))
# Get a sample with one observation from the distribution of each row.
X <- rowSums(runif(N) > Q) + 1
Isso produz a distribuição cumulativa de cada linha de e, em seguida, coleta uma observação de cada distribuição. Observe que, se pudermos reutilizar , podemos calcular uma vez e armazená-lo para uso posterior. No entanto, a pergunta precisa de algo que funcione para um diferente a cada iteração.P Q PP PQP
Se você precisar de várias ( ) observações de cada linha, substitua a última linha pela seguinte.n
# Returns an N x n matrix
X <- replicate(n, rowSums(runif(N) > Q)+1)
Em geral, essa não é uma maneira extremamente eficiente de fazer isso, mas tira proveito das Rcapacidades de vetorização, que geralmente são o principal determinante da velocidade de execução. Também é simples de entender.
Método 2 : concatenando os cdfs.
Suponha que tivéssemos uma função que pegou dois vetores, o segundo dos quais foi classificado em ordem monotônica não decrescente e encontrou o índice no segundo vetor do maior limite inferior de cada elemento no primeiro. Então, poderíamos usar esta função e um truque liso: Basta criar a soma cumulativa dos cdfs de todas as linhas. Isso fornece um vetor monotonicamente crescente com elementos no intervalo .[0,N]
Aqui está o código.
i <- 0:(N-1)
# Cumulative function of the cdfs of each row of P.
Q <- cumsum(t(P))
# Find the interval and then back adjust
findInterval(runif(N)+i, Q)-i*K+1
Observe o que a última linha faz, ela cria variáveis aleatórias distribuídas em e depois chama para encontrar o índice do maior limite inferior de cada entrada . Assim, esta diz-nos que o primeiro elemento de vai ser encontrado entre o índice de um e o índice , a segunda vai ser encontrado entre o índice e , etc, cada um de acordo com a distribuição da linha correspondente de . Então, precisamos voltar a transformar para obter cada um dos índices de volta no intervalo .K K + 1 2 K P { 1 , … , K }(0,1),(1,2),…,(N−1,N)findIntervalrunif(N)+iKK+12KP{1,…,K}
Por findIntervalser rápido, tanto em termos de algoritmo quanto de implementação, esse método acaba sendo extremamente eficiente.
Uma referência
No meu laptop antigo (MacBook Pro, 2,66 GHz, 8GB RAM), tentei isso com e e gerando 5000 amostras do tamanho , exatamente como sugerido na pergunta atualizada, para um total de 50 milhões de variáveis aleatórias .K = 100 NN=10000K=100N
O código do método 1 levou quase exatamente 15 minutos para ser executado, ou cerca de 55K variáveis aleatórias por segundo. O código do método 2 levou cerca de quatro minutos e meio para ser executado, ou cerca de 183 mil variáveis aleatórias por segundo.
Aqui está o código para a reprodutibilidade. (Observe que, conforme indicado em um comentário, é recalculado para cada uma das 5000 iterações para simular a situação do OP.)Q
# Benchmark code
N <- 10000
K <- 100
set.seed(17)
P <- matrix(runif(N*K),N,K)
P <- P / rowSums(P)
method.one <- function(P)
{
Q <- t(apply(P,1,cumsum))
X <- rowSums(runif(nrow(P)) > Q) + 1
}
method.two <- function(P)
{
n <- nrow(P)
i <- 0:(n-1)
Q <- cumsum(t(P))
findInterval(runif(n)+i, Q)-i*ncol(P)+1
}
Aqui está a saída.
# Method 1: Timing
> system.time(replicate(5e3, method.one(P)))
user system elapsed
691.693 195.812 899.246
# Method 2: Timing
> system.time(replicate(5e3, method.two(P)))
user system elapsed
182.325 82.430 273.021
Postscript : Observando o código findInterval, podemos ver que ele faz algumas verificações na entrada para ver se há NAentradas ou se o segundo argumento não está classificado. Portanto, se quiséssemos extrair mais desempenho disso, poderíamos criar nossa própria versão modificada, findIntervalque remove essas verificações que são desnecessárias no nosso caso.