Como ler de um conjunto de dados de IO alto no pytorch que cresce de época para época


8

Eu uso o Tensorflow, mas estou escrevendo documentação para usuários que normalmente variam entre as estruturas de aprendizado profundo .

Ao trabalhar com conjuntos de dados que não se encaixam no sistema de arquivos local (TB +), colho dados de um repositório de dados remoto e gravo amostras localmente em um tfrecordsformato padrão do Tensorflow .

Durante a primeira época do treinamento, apenas amostrarei alguns valores; portanto, uma época de dados locais é muito pequena, eu treino nela. Na época 2 , re-examino quais arquivos de dados foram produzidos pelos meus subprocessos de amostragem (agora mais) e treino o conjunto expandido de arquivos de dados locais para a próxima época. Repita o processo a cada época. Dessa maneira, construo um cache local de amostras e posso despejar amostras mais antigas conforme preencho o armazenamento local. O cache de amostras locais cresce aproximadamente no momento em que o modelo precisa mais da variação (na última parte do treinamento).

No Python / Tensorflow, é crucial que eu não desserialize os dados no processo de loop de treinamento do Python, porque o GIL do Python não pode suportar as taxas de transferência de dados (300-600 MB / s, os dados são científicos não compactáveis) e, portanto, o desempenho da GPU sofre quando o GIL do Python não pode atender rapidamente ao ciclo de treinamento.

Gravar as amostras em um tfrecordsarquivo a partir de subprocessos (multiprocessamento python) permite que o nativo do tensorflow TFRecordsDatasetfaça a desserialização fora do Python e, assim, contornamos os problemas do GIL do Python, e posso saturar uma GPU com altas taxas de dados de IO.

Gostaria de saber como resolveria esse problema em Pytorch. Estou escrevendo sobre a estratégia de amostragem que está sendo usada e quero fornecer recomendações específicas aos usuários do Tensorflow e do PyTorch, mas não conheço o ecossistema de pré-processamento do PyTorch o suficiente para escrever com detalhes suficientes.

Nota lateral: a única solução puramente baseada em Python para suportar essas taxas de transferência de dados pode vir no Python 3.8 com memória compartilhada e multiprocessamento do System V, mas ainda não tentei isso, pois o suporte a ela não é suficiente (em breve será ) As soluções de multiprocessamento existentes não são suficientes porque exigem desserialização no processo do loop de treinamento e, portanto, bloqueiam o GIL durante a desserialização a altas taxas de IO.


2
Como você sabe que as taxas de transferência de dados sofrem com o Python GIL? Que eu saiba, é a operação vinculada à CPU que é afetada pelo GIL na maioria dos casos, não a operação vinculada à E / S.
bombs

Nos meus testes, apenas a desserialização entre os processos Python nas taxas de dados mais rápidas que posso alcançar mantém o processo de destino com 100% de utilização da CPU. Eu tentei muitas abordagens, assíncrona, multiprocessamento e até leituras diretas de soquete. No caso de leituras diretas de soquete, posso obter 4 GB / s em todos os processos e, no momento em que tento juntar as cadeias binárias, caio para 2 GB / s, e qualquer coisa mais complexa me leva a uma taxa máxima de xfer de 1 GB / s. Isso é tudo com o processo de destino, utilizando totalmente o núcleo e bloqueando o GIL.
David Parks

Observe que isso não é realmente um problema com grandes conjuntos de dados comuns, como o imagenet, porque a IO necessária para mover JPEGs compactados em grandes redes neurais é pequena em comparação com o que o treinamento de dados científicos não compactados exige em pequenas redes.
David Parks

1
uma junção de string é categorizada em uma operação vinculada à CPU e pode facilmente exigir uma capacidade de 100% da CPU sem utilizar a capacidade de E / S da máquina. Portanto, não é uma evidência de que um GIL restrinja a taxa de transferência de E / S.
bombs

2
Essas operações triviais não reivindicam o GIL do processo principal se os dados forem carregados DataLoadercomo na minha resposta.
bombas

Respostas:


7

Na verdade, você pode facilmente desserializar dados em um subprocesso usando torch.utils.data.DataLoader. Ao definir o num_workersargumento como 1 ou um valor maior, você pode gerar subprocessos com seus próprios intérpretes de python e GILs.

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
         # loader in the main process does not claim GIL at this point

A Dataloaderrequer a torch.utils.data.Datasetpara obter dados. Pode não ser um trabalho trivial implementar uma subclasse adequada no seu caso. Caso seja necessário recriar uma Datasetinstância para cada época, você pode fazer algo assim.

for epcoh in range(epochs):
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        # Do training

ou melhor ainda

dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epcoh in range(epochs):
    last_batch_idx =  (len(dset)-1) // loader.batch_size
    for batch_idx, data in enumerate(loader):
        # Prepare next loader in advance to avoid blocking
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        # Do training

Como uma observação lateral, observe que a operação vinculada à CPU é afetada pelo GIL na maioria dos casos, não a operação vinculada a E / S, ou seja, threadingservirá para qualquer operação pesada de E / S e você nem precisa subprocess. Para mais informações, consulte esta pergunta e este artigo da Wikipedia .


Apenas para confirmar, os torch.utils.data.DataLoaderdados são colocados na GPU pelos subprocessos ou ele está tentando usar o multiprocessamento do python para movê-los para o processo do loop de treinamento? Descobri que apenas a desserialização de um processo para outro a taxas de dados próximas a 1 GB / s é> 1 núcleo completo de trabalho, daí os problemas de GIL que encontrei ao tentar essa abordagem no TF. Mas se torch.utils.data.DataLoaderestiver movendo dados para a GPU de uma maneira que não exija desserialização do Python, tudo estará bem. Eu só quero confirmar esse entendimento.
David Parks

@DavidParks Que função específica você usa quando está testando a desserialização de um processo para outro? Parece que o processo de desserialização envolve uma operação vinculada à CPU, daí os problemas de GIL.
bombs

Eu tentei multiprocessamento (muito lento), Pipes (melhor) e leituras brutas de soquete (melhor). Nada disso funciona quando as taxas de E / S são uma fração substancial de GB / s, apenas mover muitos dados exige mais de um núcleo e, portanto, as soluções Python (anteriores à memória compartilhada 3.8 e System V) se desintegram no Tensorflow. É por isso que escrevo nos arquivos tfrecords e deixo o tensorflow fazer a desserialização fora do Python. O TF não bloqueia o Python GIL e paralela as operações; portanto, meu processo principal usa 600% da CPU, enquanto o Python GIL permanece ocioso e livre para atender ao ciclo de treinamento.
David Parks

@DavidParks Quero dizer, que tipo de função ou biblioteca de desserialização você usa? (não biblioteca de comunicação entre processos). torch.utils.data.DataLoaderpode utilizar facilmente 600% da CPU ou mais, e o processo principal não precisa de muita energia da CPU na maioria dos casos quando o treinamento é principalmente computação da GPU (quando o treinamento é principalmente computação da CPU, ainda não há problema porque a operação da matriz do pytorch pode facilmente utilizar vários CPUs).
bombas

Basta usar pickle para desserializar os processos python e, em seguida, uma função de gerador python para alimentar amostras no ecossistema TensorFlow. Essa é a abordagem que falha em mim. Depois que os dados estão no ecossistema TensorFlow, eles são colocados na GPU e o treinamento é outra história. O TF não fornece uma maneira para os subprocessos python fornecerem dados ao TF, você só tem algumas opções e os dados formatados em tfrecords (formato de Buffers de Protocolo) são os mais lógicos. Parece que pode ser mais fácil no PyTorch, então alguns usuários do PyTorch aqui o validarão.
David Parks
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.