Cross-validation: testando o desempenho de um classificador

Olá pessoal, tudo certo!?

Hoje vamos falar sobre aprendizado de máquina. Não vamos falar sobre as técnicas de classificação, mas sobre as técnicas de verificação de desempenho dos algoritmos.

Dados e características

O exemplo de teste será a classificação de texto baseado no tutorial de classificação de texto do scikit-learn. O código inicial é:

Entre as linhas 12 e 15 definimos os dados que serão usados para o teste de classificação. Os dados são do dataset Twenty Newsgroups que está disponível no scikit-learn e será baixado apenas quando a aplicação for executada pela primeira vez. Estamos fazendo o download apenas das categorias definidas nas linhas 12 e 13, mas é possível usar outras categorias. Na linha 18 definimos o classificador, que é um SVM utilizando a técnica um-contra-um para classificar mais de duas classes. Nas linhas 21 e 22 iniciamos as classes que fazem a extração de características. CountVectorizer quebra o texto em tokens e faz a contagem da bag of words. TfidfTransformer transforma as contagens em frequências usando Term Frequency-Inverse Document Frequency. A função _fitclassifier na linha 24 treina o classificador com dos dados _traindata e as classes _trainlabels. As linhas 26 e 27 fazem a extração de características dos dados e a linha 28 efetua o treinamento do classificador. A função predict na linha 30 faz a classificação dos dados _testdata. As linhas 31 e 32 fazem a extração de características e a linha 33 retorna as classificações para os dados.

Validação cruzada e a matriz de confusão

A validação cruzada nada mais é do que separar uma parte dos dados para ser usado treinamento e outra parte para ser usada como teste. Para isso, o scikit-learn oferece uma função que faz a separação automaticamente dos dados. Nas linhas 4 e 5 os dados são separados para validação cruzada usando o método. Os dados para treinamento do classificador são armazenados em _traindata e _trainlabels e os dados de teste são armazenados em _testdata e _traindata. O tamanho dos conjuntos e controlado por _testsize, sendo que o valor 0.1 representa 10% de dados para teste. As linhas 7 e 8 treinam o classificador e fazem a classificação dos dados de teste. Na linha 9 é exibida a matriz de confusão que indica o desempenho do classificador. Abaixo está o resultado da execução do código acima.

alt.atheism comp.graphics sci.med soc.religion .christian Total
alt.atheism 50 1 1 1 53
comp.graphics 0 61 0 0 61
sci.med 0 3 55 1 59
soc.religion .christian 0 2 0 51 53
Total 50 67 56 53 226

Os cabeçalhos e os totais eu adicionei manualmente. O Total ao final de cada linha indica quantas amostras de determinada classe existiam no dataset de teste. Por exemplo, existem 53 amostras da classe alt.atheism nos dados para treinamento. O Total das colunas indica quantas amostras de teste foram classificados como determinada classe no dataset de teste. Por exemplo, 67 amostras foram classificadas como comp.graphics.

Validação cruzada k-fold

É uma forma de validação cruzada que divide os dados em k subconjuntos (de mesmo tamanho se possível). Em cada rodada, um subconjunto é separado para teste e os k - 1 subconjuntos restantes são usados para treinamento. A ideia é que cada amostra seja usada para teste apenas uma vez e k - 1 vezes para treinamento e ao final o erro médio é computado como uma métrica do desempenho do classificador. O scikit-learn também oferece uma função que facilita a execução da validação cruzada k-fold. As linhas 2 e 3 fazem a extração de características do dataset. A linha 5 efetua a validação cruzada k-fold nos dados, sendo que o valor k é definido por cv que nesse caso é 5. Portanto os dados serão divididos em cinco subconjuntos e testados. A linha 7 exibe os resultados obtidos para cada execução e a linha 8 exibe a média de acertos do classificador nas k execuções. O resultado da execução do código acima é [ 0.96460177 0.97345133 0.96238938 0.9579646 0.97327394] e Accuracy: 0.9663362043479118 +/- 0.01224505099481929.

Alterando o classificador

Agora é possível alterar o classificador e as características usadas para ver se há uma melhora (ou piora) na classificação. Por exemplo, é possível alterar o kernel do SVM de ‘linear’ para ‘rbf’ alterando a linha _classifier = OneVsOneClassifier(SVC(kernel = ‘linear’, randomstate = 84)) para _classifier = OneVsOneClassifier(SVC(kernel = ‘rbf’, randomstate = 84)). O resultado da validação cruzada k-fold é: [ 0.26548673 0.26548673 0.26548673 0.26548673 0.26503341] e Accuracy: 0.2653960620454501 +/- 0.0003626544730670034. A matriz de confusão é:

alt.atheism comp.graphics sci.med soc.religion .christian Total
alt.atheism 0 0 0 53 53
comp.graphics 0 0 0 61 61
sci.med 0 0 0 59 59
soc.religion .christian 0 0 0 53 53
Total 0 0 0 226 226

É possível perceber que apenas uma mudança no kernel resultou em uma piora do classificador. Por isso é importante testar os parâmetros disponíveis no classificador e as diferentes características que podem ser extraídas dos dados para verificar o desempenho da classificação.

Para hoje é isso. O código final está no GitHub. Abraços e até a próxima :)

Referências

20 Newsgroups.

Bag-of-words model.

Cross-validation: evaluating estimator performance.

Cross validation.

Feature extraction.

Tf-idf :: A Single-Page Tutorial - Information Retrieval and Text Mining.

Tokenization (lexical analysis).

Simple guide to confusion matrix terminology.

sklearn.metrics.confusion_matrix.

sklearn.feature_extraction.text.TfidfTransformer.

sklearn.feature_extraction.text.CountVectorizer.

sklearn.cross_validation.cross_val_score.

Working With Text Data.