Converter a coluna Spark DataFrame em lista python


104

Eu trabalho em um dataframe com duas colunas, mvv e count.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

Eu gostaria de obter duas listas contendo valores de mvv e valor de contagem. Algo como

mvv = [1,2,3,4]
count = [5,9,3,1]

Então, tentei o seguinte código: A primeira linha deve retornar uma lista python de linhas. Eu queria ver o primeiro valor:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

Mas recebo uma mensagem de erro com a segunda linha:

AttributeError: getInt


A partir de faísca 2.3, este código é o mais rápido e menos susceptível de causar exceções OutOfMemory: list(df.select('mvv').toPandas()['mvv']). O Arrow foi integrado ao PySpark, que aumentou toPandassignificativamente. Não use as outras abordagens se estiver usando o Spark 2.3+. Veja minha resposta para mais detalhes de benchmarking.
Poderes em

Respostas:


141

Veja, por que esse jeito que você está fazendo não está funcionando. Primeiro, você está tentando obter um inteiro de um tipo de linha , a saída de sua coleta é assim:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

Se você pegar algo assim:

>>> firstvalue = mvv_list[0].mvv
Out: 1

Você obterá o mvvvalor. Se você quiser todas as informações do array, pode pegar algo assim:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

Mas se você tentar o mesmo para a outra coluna, obterá:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

Isso acontece porque counté um método integrado. E a coluna tem o mesmo nome que count. Uma solução alternativa para fazer isso é alterar o nome da coluna de countpara _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

Mas essa solução alternativa não é necessária, pois você pode acessar a coluna usando a sintaxe do dicionário:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

E finalmente funcionará!


funciona muito bem para a primeira coluna, mas não funciona para a contagem de colunas, acho que por causa de (a função de contagem de centelha)
a.moussa

Você pode adicionar o que está fazendo com a contagem? Adicione aqui nos comentários.
Thiago Baldim

obrigado pela sua resposta. Portanto, esta linha funciona mvv_list = [int (i.mvv) para i em mvv_count.select ('mvv'). collect ()] mas não este count_list = [int (i.count) para i em mvv_count .select ('count'). collect ()] retorna sintaxe inválida
a.moussa

Não é necessário adicionar este select('count')uso desta forma: count_list = [int(i.count) for i in mvv_list.collect()]adicionarei o exemplo à resposta.
Thiago Baldim

1
@ a.moussa [i.['count'] for i in mvv_list.collect()]trabalha para tornar explícito o uso da coluna chamada 'count' e não a countfunção
user989762

103

Seguir uma linha fornece a lista que você deseja.

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

3
Em termos de desempenho, essa solução é muito mais rápida do que sua solução mvv_list = [int (i.mvv) para i em mvv_count.select ('mvv'). Collect ()]
Chanaka Fernando

Esta é de longe a melhor solução que já vi. Obrigado.
hui chen

22

Isso lhe dará todos os elementos como uma lista.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

1
Esta é a solução mais rápida e eficiente para Spark 2.3+. Veja os resultados do benchmarking em minha resposta.
Poderes em

15

O código a seguir irá ajudá-lo

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

3
Esta deve ser a resposta aceita. a razão é que você está permanecendo em um contexto de faísca durante todo o processo e então você coleta no final ao invés de sair do contexto de faísca mais cedo, o que pode causar uma coleta maior dependendo do que você está fazendo.
AntiPawn79 de

15

Em meus dados, obtive estes benchmarks:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0,52 s

>>> [row[col] for row in data.collect()]

0,271 s

>>> list(data.select(col).toPandas()[col])

0,427 s

O resultado é o mesmo


1
Se você usar em toLocalIteratorvez collectdisso, será ainda mais eficiente em termos de memória[row[col] for row in data.toLocalIterator()]
oglop

5

Se você receber o erro abaixo:

AttributeError: o objeto 'list' não tem nenhum atributo 'coletar'

Este código resolverá seus problemas:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

Também recebi esse erro e esta solução resolveu o problema. Mas por que recebi o erro? (Muitos outros parecem não entender isso!)
bikashg

1

Fiz uma análise de benchmarking e list(mvv_count_df.select('mvv').toPandas()['mvv'])é o método mais rápido. Estou muito surpreso.

Eu executei as diferentes abordagens em conjuntos de dados de 100 mil / 100 milhões de linhas usando um cluster i3.xlarge de 5 nós (cada nó tem 30,5 GBs de RAM e 4 núcleos) com Spark 2.4.5. Os dados foram distribuídos uniformemente em 20 arquivos Parquet compactados com uma única coluna.

Aqui estão os resultados do benchmarking (tempos de execução em segundos):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Regras de ouro a serem seguidas ao coletar dados no nó do driver:

  • Tente resolver o problema com outras abordagens. Coletar dados para o nó do driver é caro, não aproveita o poder do cluster Spark e deve ser evitado sempre que possível.
  • Colete o mínimo de linhas possível. Agregue, desduplique, filtre e remova colunas antes de coletar os dados. Envie o mínimo de dados possível ao nó do driver.

toPandas foi significativamente melhorado no Spark 2.3 . Provavelmente não é a melhor abordagem se você estiver usando uma versão do Spark anterior à 2.3.

Veja aqui mais detalhes / resultados de benchmarking.


1

Uma possível solução é usar a collect_list()função de pyspark.sql.functions. Isso agregará todos os valores da coluna em uma matriz pyspark que é convertida em uma lista Python quando coletada:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.