Eu uso o Tensorflow, mas estou escrevendo documentação para usuários que normalmente variam entre as estruturas de aprendizado profundo .
Ao trabalhar com conjuntos de dados que não se encaixam no sistema de arquivos local (TB +), colho dados de um repositório de dados remoto e gravo amostras localmente em um tfrecords
formato padrão do Tensorflow .
Durante a primeira época do treinamento, apenas amostrarei alguns valores; portanto, uma época de dados locais é muito pequena, eu treino nela. Na época 2 , re-examino quais arquivos de dados foram produzidos pelos meus subprocessos de amostragem (agora mais) e treino o conjunto expandido de arquivos de dados locais para a próxima época. Repita o processo a cada época. Dessa maneira, construo um cache local de amostras e posso despejar amostras mais antigas conforme preencho o armazenamento local. O cache de amostras locais cresce aproximadamente no momento em que o modelo precisa mais da variação (na última parte do treinamento).
No Python / Tensorflow, é crucial que eu não desserialize os dados no processo de loop de treinamento do Python, porque o GIL do Python não pode suportar as taxas de transferência de dados (300-600 MB / s, os dados são científicos não compactáveis) e, portanto, o desempenho da GPU sofre quando o GIL do Python não pode atender rapidamente ao ciclo de treinamento.
Gravar as amostras em um tfrecords
arquivo a partir de subprocessos (multiprocessamento python) permite que o nativo do tensorflow TFRecordsDataset
faça a desserialização fora do Python e, assim, contornamos os problemas do GIL do Python, e posso saturar uma GPU com altas taxas de dados de IO.
Gostaria de saber como resolveria esse problema em Pytorch. Estou escrevendo sobre a estratégia de amostragem que está sendo usada e quero fornecer recomendações específicas aos usuários do Tensorflow e do PyTorch, mas não conheço o ecossistema de pré-processamento do PyTorch o suficiente para escrever com detalhes suficientes.
Nota lateral: a única solução puramente baseada em Python para suportar essas taxas de transferência de dados pode vir no Python 3.8 com memória compartilhada e multiprocessamento do System V, mas ainda não tentei isso, pois o suporte a ela não é suficiente (em breve será ) As soluções de multiprocessamento existentes não são suficientes porque exigem desserialização no processo do loop de treinamento e, portanto, bloqueiam o GIL durante a desserialização a altas taxas de IO.
DataLoader
como na minha resposta.