Cross-validation: testando o desempenho de um classificador

Olá pessoal, tudo certo!?

Hoje vamos falar sobre aprendizado de máquina. Não vamos falar sobre as técnicas de classificação, mas sobre as técnicas de verificação de desempenho dos algoritmos.

Dados e características

O exemplo de teste será a classificação de texto baseado no tutorial de classificação de texto do scikit-learn. O código inicial é:

Entre as linhas 12 e 15 definimos os dados que serão usados para o teste de classificação. Os dados são do dataset Twenty Newsgroups que está disponível no scikit-learn e será baixado apenas quando a aplicação for executada pela primeira vez. Estamos fazendo o download apenas das categorias definidas nas linhas 12 e 13, mas é possível usar outras categorias. Na linha 18 definimos o classificador, que é um SVM utilizando a técnica um-contra-um para classificar mais de duas classes. Nas linhas 21 e 22 iniciamos as classes que fazem a extração de características. CountVectorizer quebra o texto em tokens e faz a contagem da bag of words. TfidfTransformer transforma as contagens em frequências usando Term Frequency-Inverse Document Frequency. A função fit_classifier na linha 24 treina o classificador com dos dados train_data e as classes train_labels. As linhas 26 e 27 fazem a extração de características dos dados e a linha 28 efetua o treinamento do classificador. A função predict na linha 30 faz a classificação dos dados test_data. As linhas 31 e 32 fazem a extração de características e a linha 33 retorna as classificações para os dados.

Validação cruzada e a matriz de confusão

A validação cruzada nada mais é do que separar uma parte dos dados para ser usado treinamento e outra parte para ser usada como teste. Para isso, o scikit-learn oferece uma função que faz a separação automaticamente dos dados. Nas linhas 4 e 5 os dados são separados para validação cruzada usando o método. Os dados para treinamento do classificador são armazenados em train_data e train_labels e os dados de teste são armazenados em test_data e train_data. O tamanho dos conjuntos e controlado por test_size, sendo que o valor 0.1 representa 10% de dados para teste. As linhas 7 e 8 treinam o classificador e fazem a classificação dos dados de teste. Na linha 9 é exibida a matriz de confusão que indica o desempenho do classificador. Abaixo está o resultado da execução do código acima.

alt.atheismcomp.graphicssci.medsoc.religion .christianTotal
alt.atheism5011153
comp.graphics0610061
sci.med0355159
soc.religion .christian0205153
Total50675653226

Os cabeçalhos e os totais eu adicionei manualmente. O Total ao final de cada linha indica quantas amostras de determinada classe existiam no dataset de teste. Por exemplo, existem 53 amostras da classe alt.atheism nos dados para treinamento. O Total das colunas indica quantas amostras de teste foram classificados como determinada classe no dataset de teste. Por exemplo, 67 amostras foram classificadas como comp.graphics.

Validação cruzada k-fold

É uma forma de validação cruzada que divide os dados em k subconjuntos (de mesmo tamanho se possível). Em cada rodada, um subconjunto é separado para teste e os k - 1 subconjuntos restantes são usados para treinamento. A ideia é que cada amostra seja usada para teste apenas uma vez e k - 1 vezes para treinamento e ao final o erro médio é computado como uma métrica do desempenho do classificador. O scikit-learn também oferece uma função que facilita a execução da validação cruzada k-fold. As linhas 2 e 3 fazem a extração de características do dataset. A linha 5 efetua a validação cruzada k-fold nos dados, sendo que o valor k é definido por cv que nesse caso é 5. Portanto os dados serão divididos em cinco subconjuntos e testados. A linha 7 exibe os resultados obtidos para cada execução e a linha 8 exibe a média de acertos do classificador nas k execuções. O resultado da execução do código acima é [ 0.96460177 0.97345133 0.96238938 0.9579646 0.97327394] e Accuracy: 0.9663362043479118 +/- 0.01224505099481929.

Alterando o classificador

Agora é possível alterar o classificador e as características usadas para ver se há uma melhora (ou piora) na classificação. Por exemplo, é possível alterar o kernel do SVM de ’linear’ para ‘rbf’ alterando a linha classifier = OneVsOneClassifier(SVC(kernel = ’linear’, random_state = 84)) para classifier = OneVsOneClassifier(SVC(kernel = ‘rbf’, random_state = 84)). O resultado da validação cruzada k-fold é: [ 0.26548673 0.26548673 0.26548673 0.26548673 0.26503341] e Accuracy: 0.2653960620454501 +/- 0.0003626544730670034. A matriz de confusão é:

alt.atheismcomp.graphicssci.medsoc.religion .christianTotal
alt.atheism0005353
comp.graphics0006161
sci.med0005959
soc.religion .christian0005353
Total000226226

É possível perceber que apenas uma mudança no kernel resultou em uma piora do classificador. Por isso é importante testar os parâmetros disponíveis no classificador e as diferentes características que podem ser extraídas dos dados para verificar o desempenho da classificação.

Para hoje é isso. O código final está no GitHub. Abraços e até a próxima :)

Referências

20 Newsgroups.

Bag-of-words model.

Cross-validation: evaluating estimator performance.

Cross validation.

Feature extraction.

Tf-idf :: A Single-Page Tutorial - Information Retrieval and Text Mining.

Tokenization (lexical analysis).

Simple guide to confusion matrix terminology.

sklearn.metrics.confusion_matrix.

sklearn.feature_extraction.text.TfidfTransformer.

sklearn.feature_extraction.text.CountVectorizer.

sklearn.cross_validation.cross_val_score.

Working With Text Data.