Calcular AUC en R?


Dado un vector de puntuaciones y un vector de etiquetas de clase reales, ¿cómo se calcula una métrica AUC de un solo número para un clasificador binario en el lenguaje R o en inglés simple?

Página 9 de "AUC: una mejor medida..." parece requerir conocer las etiquetas de la clase, y aquí está un ejemplo en MATLAB donde no entiendo

R(Actual == 1))

¿Porque R (no debe confundirse con el lenguaje R) se define como un vector pero se usa como una función?

Author: AGS, 2011-02-05

10 answers

Como han mencionado otros, puede calcular el AUC utilizando el paquete ROCR. Con el paquete ROCR también puede trazar la curva ROC, la curva de elevación y otras medidas de selección de modelos.

Puede calcular el AUC directamente sin usar ningún paquete utilizando el hecho de que el AUC es igual a la probabilidad de que un verdadero positivo sea mayor que un verdadero negativo.

Por ejemplo, si pos.scores es un vector que contiene una puntuación de los ejemplos positivos, y neg.scores es un vector conteniendo los ejemplos negativos, entonces el AUC se aproxima por:

> mean(sample(pos.scores,1000,replace=T) > sample(neg.scores,1000,replace=T))
[1] 0.7261

Dará una aproximación del AUC. También puede estimar la varianza del AUC mediante bootstrapping:

> aucs = replicate(1000,mean(sample(pos.scores,1000,replace=T) > sample(neg.scores,1000,replace=T)))
 30
Author: erik,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2011-02-15 05:44:00

El paquete ROCR calculará el AUC entre otras estadísticas:

auc.tmp <- performance(pred,"auc"); auc <- as.numeric([email protected])
 34
Author: semaj,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2013-12-01 12:18:22

Con el paquete pROC puede usar la función auc() como este ejemplo desde la página de ayuda:

> data(aSAH)
> 
> # Syntax (response, predictor):
> auc(aSAH$outcome, aSAH$s100b)
Area under the curve: 0.7314
 26
Author: J. Win.,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-22 20:21:03

Sin paquetes adicionales:

true_Y = c(1,1,1,1,2,1,2,1,2,2)
probs = c(1,0.999,0.999,0.973,0.568,0.421,0.382,0.377,0.146,0.11)

getROC_AUC = function(probs, true_Y){
    probsSort = sort(probs, decreasing = TRUE, index.return = TRUE)
    val = unlist(probsSort$x)
    idx = unlist(probsSort$ix)  

    roc_y = true_Y[idx];
    stack_x = cumsum(roc_y == 2)/sum(roc_y == 2)
    stack_y = cumsum(roc_y == 1)/sum(roc_y == 1)    

    auc = sum((stack_x[2:length(roc_y)]-stack_x[1:length(roc_y)-1])*stack_y[2:length(roc_y)])
    return(list(stack_x=stack_x, stack_y=stack_y, auc=auc))
}

aList = getROC_AUC(probs, true_Y) 

stack_x = unlist(aList$stack_x)
stack_y = unlist(aList$stack_y)
auc = unlist(aList$auc)

plot(stack_x, stack_y, type = "l", col = "blue", xlab = "False Positive Rate", ylab = "True Positive Rate", main = "ROC")
axis(1, seq(0.0,1.0,0.1))
axis(2, seq(0.0,1.0,0.1))
abline(h=seq(0.0,1.0,0.1), v=seq(0.0,1.0,0.1), col="gray", lty=3)
legend(0.7, 0.3, sprintf("%3.3f",auc), lty=c(1,1), lwd=c(2.5,2.5), col="blue", title = "AUC")

introduzca la descripción de la imagen aquí

 16
Author: AGS,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2013-09-28 21:01:36

Encontré que algunas de las soluciones aquí son lentas y/o confusas (y algunas de ellas no manejan los vínculos correctamente), así que escribí mi propia función data.table basada en auc_roc() en mi paquete R mltools.

library(data.table)
library(mltools)

preds <- c(.1, .3, .3, .9)
actuals <- c(0, 0, 1, 1)

auc_roc(preds, actuals)  # 0.875

auc_roc(preds, actuals, returnDT=TRUE)
   Pred CountFalse CountTrue CumulativeFPR CumulativeTPR AdditionalArea CumulativeArea
1:  0.9          0         1           0.0           0.5          0.000          0.000
2:  0.3          1         1           0.5           1.0          0.375          0.375
3:  0.1          1         0           1.0           1.0          0.500          0.875
 6
Author: Ben,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-08-28 20:02:09

A lo largo de las líneas de la respuesta de erik, también debería ser capaz de calcular el ROC directamente comparando todos los pares posibles de valores de pos.puntuaciones y neg.puntuación:

score.pairs <- merge(pos.scores, neg.scores)
names(score.pairs) <- c("pos.score", "neg.score")
sum(score.pairs$pos.score > score.pairs$neg.score) / nrow(score.pairs)

Ciertamente menos eficiente que el enfoque de muestra o el pROC::auc, pero más estable que el primero y que requiere menos instalación que el segundo.

Relacionado: cuando probé esto dio resultados similares al valor de pROC, pero no exactamente el mismo( apagado por 0.02 o así); el resultado fue más cercano a la muestra enfoque con muy alta N. Si alguien tiene ideas de por qué podría ser que estaría interesado.

 3
Author: Max Ghenis,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2013-01-15 14:10:55

Combinando código de ISL 9.6.3 Curvas ROC, junto con @J. Won.respuesta a esta pregunta y algunos lugares más, la siguiente gráfica de la curva ROC e imprime el AUC en la parte inferior derecha de la gráfica.

Debajo de probs está un vector numérico de probabilidades predichas para la clasificación binaria y test$label contiene las etiquetas verdaderas de los datos de prueba.

require(ROCR)
require(pROC)

rocplot <- function(pred, truth, ...) {
  predob = prediction(pred, truth)
  perf = performance(predob, "tpr", "fpr")
  plot(perf, ...)
  area <- auc(truth, pred)
  area <- format(round(area, 4), nsmall = 4)
  text(x=0.8, y=0.1, labels = paste("AUC =", area))

  # the reference x=y line
  segments(x0=0, y0=0, x1=1, y1=1, col="gray", lty=2)
}

rocplot(probs, test$label, col="blue")

Esto da una trama como esta:

introduzca la descripción de la imagen aquí

 3
Author: arun,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-07-20 22:15:28

Normalmente uso la función ROC del paquete Diagnosticmed. Me gusta el gráfico que produce. El AUC se devuelve junto con su intervalo de confianza y también se menciona en el gráfico.

ROC(classLabels,scores,Full=TRUE)
 2
Author: George Dontas,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2011-02-05 08:50:34

Actualmente la respuesta más votada es incorrecta, porque no tiene en cuenta los empates. Cuando las puntuaciones positivas y negativas son iguales, el AUC debe ser 0,5. A continuación se corrige el ejemplo.

computeAUC <- function(pos.scores, neg.scores, n_sample=100000) {
  # Args:
  #   pos.scores: scores of positive observations
  #   neg.scores: scores of negative observations
  #   n_samples : number of samples to approximate AUC

  pos.sample <- sample(pos.scores, n_sample, replace=T)
  neg.sample <- sample(neg.scores, n_sample, replace=T)
  mean(1.0*(pos.sample > neg.sample) + 0.5*(pos.sample==neg.sample))
}
 2
Author: Jussi Kujala,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-01-04 07:45:04

Puede obtener más información sobre AUROC en este post de blog de Miron Kursa :

Https://mbq.me/blog/augh-roc /

Proporciona una función rápida para AUROC:

# By Miron Kursa https://mbq.me
auroc <- function(score, bool) {
  n1 <- sum(!bool)
  n2 <- sum(bool)
  U  <- sum(rank(score)[!bool]) - n1 * (n1 + 1) / 2
  return(1 - U / n1 / n2)
}

Vamos a probarlo:

set.seed(42)
score <- rnorm(1e3)
bool  <- sample(c(TRUE, FALSE), 1e3, replace = TRUE)

pROC::auc(bool, score)
mltools::auc_roc(score, bool)
ROCR::performance(ROCR::prediction(score, bool), "auc")@y.values[[1]]
auroc(score, bool)

0.51371668847094
0.51371668847094
0.51371668847094
0.51371668847094

auroc() es 100 veces más rápido que pROC::auc() y computeAUC().

auroc() es 10 veces más rápido que mltools::auc_roc() y ROCR::performance().

print(microbenchmark(
  pROC::auc(bool, score),
  computeAUC(score[bool], score[!bool]),
  mltools::auc_roc(score, bool),
  ROCR::performance(ROCR::prediction(score, bool), "auc")@y.values,
  auroc(score, bool)
))

Unit: microseconds
                                                             expr       min
                                           pROC::auc(bool, score) 21000.146
                            computeAUC(score[bool], score[!bool]) 11878.605
                                    mltools::auc_roc(score, bool)  5750.651
 ROCR::performance(ROCR::prediction(score, bool), "auc")@y.values  2899.573
                                               auroc(score, bool)   236.531
         lq       mean     median        uq        max neval  cld
 22005.3350 23738.3447 22206.5730 22710.853  32628.347   100    d
 12323.0305 16173.0645 12378.5540 12624.981 233701.511   100   c 
  6186.0245  6495.5158  6325.3955  6573.993  14698.244   100  b  
  3019.6310  3300.1961  3068.0240  3237.534  11995.667   100 ab  
   245.4755   253.1109   251.8505   257.578    300.506   100 a   
 0
Author: Kamil Slowikowski,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-05-06 16:40:38