library(ggraph)
library(scatterpie)
# Build a normalised Seurat object (SCT + PCA/UMAP/t-SNE) from two text files.
#
# datafile : expression table; the first column holds feature ids (used as
#            row names), every remaining column is one cell.
# metafile : cell metadata table; the first column holds cell ids matching
#            the expression column names.
# data.id  : project name given to CreateSeuratObject().
# Returns the Seurat object after SCTransform, variable-feature selection,
# PCA, and UMAP/t-SNE on PCs 1:10.
load_seurat_obj <- function(datafile,metafile,data.id)
{
  require(Seurat)

  data.df <- data.table::fread(datafile)
  data.mat <- as.matrix(data.df[, -1])
  rownames(data.mat) <- data.df[[1]]
  # drop ERCC spike-in rows; drop = FALSE keeps a matrix even if one row is left
  data.mat <- data.mat[!grepl("^ERCC-", rownames(data.mat)), , drop = FALSE]

  # data.table objects do not carry row names, yet the row names assigned
  # below are what tie each metadata row to its cell (presumably AddMetaData()
  # aligns on them), so work with a plain data.frame
  meta.df <- as.data.frame(data.table::fread(metafile))
  meta.df <- meta.df[match(colnames(data.mat), meta.df[[1]]), , drop = FALSE]
  rownames(meta.df) <- colnames(data.mat)
  colnames(meta.df) <- gsub("characteristics..", "", colnames(meta.df))

  # create seurat object (cells/features failing the minimums are removed)
  seu <- CreateSeuratObject(counts = data.mat, project = data.id,
                            min.cells = 3, min.features = 10)
  seu <- AddMetaData(object = seu, metadata = meta.df)

  # normalize and embed
  seu <- SCTransform(seu, verbose = FALSE)
  seu <- FindVariableFeatures(object = seu,
                              selection.method = "vst", nfeatures = 3000,
                              verbose = FALSE)
  seu <- RunPCA(seu, features = VariableFeatures(object = seu))
  seu <- RunUMAP(object = seu, reduction = "pca", dims = 1:10)
  seu <- RunTSNE(object = seu, reduction = "pca", dims = 1:10)
  seu
}

# Plot each cell as a pie chart of its cell-line mixture at its position in a
# 2-D embedding.
#
# seu      : Seurat object holding the reduction `dimred` and the mixture
#            columns `cell.col` in its meta.data.
# dimred   : reduction name passed to Embeddings(); only its first two
#            dimensions are plotted, so multi-dim reductions (e.g. "pca")
#            work as well as 2-D ones.
# cell.col : meta.data columns giving per-cell amounts of each component;
#            each row is rescaled to proportions summing to 1.
# Returns a ggplot object (scatterpie layer, equal axes, theme_bw).
get_mixture_dimred <- function(seu,dimred,
                               cell.col = c("H1975.cells","H2228.cells","HCC827.cells"))
{
  require(ggplot2)
  require(scatterpie)

  # per-cell mixture proportions (a zero-total row yields NaN)
  cell.mat <- as.matrix(seu@meta.data[, cell.col, drop = FALSE])
  cell.mat <- cell.mat / rowSums(cell.mat)

  # keep only the first two embedding dimensions for the x/y coordinates
  xy <- Embeddings(seu, dimred)[, 1:2, drop = FALSE]
  colnames(xy) <- c("x", "y")

  pdat <- as.data.frame(cbind(xy, cell.mat))

  ggplot() +
    geom_scatterpie(aes(x = x, y = y), data = pdat, cols = cell.col,
                    color = NA, alpha = 0.5) +
    coord_equal() +
    theme_bw()
}

# Average silhouette-style score of one module (a set of cell names).
#
# For each member i:
#   a_i = sum of distances to the module members / (len - 1)
#         (the self-distance d[i, i] is included in the sum, so this is the
#         mean distance to the other members when the diagonal is zero)
#   b_i = mean distance to ALL cells outside the module (pooled, not the
#         nearest other cluster as in the classic silhouette)
#   s_i = (b_i - a_i) / max(a_i, b_i)
#
# mod : character vector of cell names; names absent from colnames(d) are
#       ignored.
# d   : square distance matrix whose row and column names are cell names.
# Returns mean(s_i); 0 when fewer than two members remain, NaN when no cell
# lies outside the module (b_i undefined).
get_cluster_sil <- function(mod,d)
{
  mod <- mod[mod %in% colnames(d)]
  len <- length(mod)
  if (len <= 1) return(0)

  rows <- match(mod, rownames(d))
  inside <- match(mod, colnames(d))
  outside <- which(!(colnames(d) %in% mod))

  # drop = FALSE keeps a matrix when a single column is selected; otherwise
  # rowMeans()/rowSums() fail on the resulting plain vector
  ai <- rowSums(d[rows, inside, drop = FALSE], na.rm = TRUE) / (len - 1)
  bi <- rowMeans(d[rows, outside, drop = FALSE], na.rm = TRUE)

  si <- (bi - ai) / pmax(ai, bi)
  mean(si)
}

# Score several clusterings of the same cells with fpc::cluster.stats(),
# using Euclidean distances between per-cell mixture proportions.
#
# msc.res  : MSC result list; reads $module.table (columns parent.cluster,
#            cluster.name), $modules (named list of member cell names) and
#            $g (igraph graph whose vertex names are cell names).
# seu      : Seurat object whose meta.data holds the `cell.col` columns and
#            the clustering columns ("*_snn_res.*", optionally "SC3_clusters").
# cell.col : meta.data columns with per-cell amounts of each component.
# Returns a named list with one cluster.stats() result per clustering column,
# plus "MSC" for the modules whose parent is "M0".
evaluate_partition <- function(msc.res,seu,
                               cell.col = c("H1975.cells","H2228.cells","HCC827.cells"))
{
  require(fpc)

  # per-cell mixture proportions: each row rescaled to sum to 1
  cell.mat <- as.matrix(seu@meta.data[, cell.col, drop = FALSE])
  cell.mat <- cell.mat / rowSums(cell.mat)

  # MSC partition at the first layer (children of "M0"); a cell in several
  # modules keeps the label of the last one
  htbl <- subset(msc.res$module.table, parent.cluster == "M0")
  modules <- msc.res$modules[match(htbl$cluster.name, names(msc.res$modules))]
  node.names <- V(msc.res$g)$name
  vec <- rep(NA, vcount(msc.res$g))
  for (i in seq_along(modules)) vec[node.names %in% modules[[i]]] <- names(modules)[i]

  # collate all clustering results; drop = FALSE keeps a data.frame when
  # only one clustering column exists
  meta <- seu@meta.data
  clus.idx <- c(grep("(.*)_snn_res\\.(.*)$", colnames(meta)),
                which(colnames(meta) == "SC3_clusters"))
  clus.df <- data.frame(meta[, clus.idx, drop = FALSE],
                        MSC = vec[match(colnames(seu), node.names)])

  # discard cells with any NA/NaN proportion (e.g. a zero total)
  keep <- complete.cases(cell.mat)
  cell.mat <- cell.mat[keep, , drop = FALSE]
  clus.df <- clus.df[keep, , drop = FALSE]

  # get cluster quality index per clustering; unlabelled cells are skipped
  cstats <- vector("list", ncol(clus.df))
  names(cstats) <- colnames(clus.df)
  for (i in seq_len(ncol(clus.df)))
  {
    cls <- as.character(clus.df[[i]])
    cls <- match(cls, setdiff(unique(cls), NA))
    labelled <- !is.na(cls)
    d <- dist(cell.mat[labelled, , drop = FALSE])
    cstats[[i]] <- cluster.stats(d = d, clustering = cls[labelled],
                                 noisecluster = FALSE, silhouette = TRUE,
                                 G2 = FALSE, G3 = FALSE,
                                 wgap = TRUE, sepindex = TRUE, sepprob = 0.1,
                                 sepwithnoise = TRUE)
  }

  cstats
}

# Run the CIDR clustering pipeline (cidr package) and return the final
# scData object produced by scCluster().
#
# tags : handed unchanged to scDataConstructor(); presumably a
#        features x cells count matrix — TODO confirm against cidr docs.
# nCluster() is invoked only for its side effect; its return value is
# discarded (presumably it draws a diagnostic plot — confirm).
run_cidr <- function(tags)
{
  require(cidr)
  cidr.obj <- scDataConstructor(tags)
  cidr.obj <- determineDropoutCandidates(cidr.obj)
  cidr.obj <- wThreshold(cidr.obj)
  cidr.obj <- scDissim(cidr.obj)
  cidr.obj <- scPCA(cidr.obj)
  cidr.obj <- nPC(cidr.obj)
  nCluster(cidr.obj)
  scCluster(cidr.obj)
}

# Add benchmark clusterings to a Seurat object:
#   * Seurat graph clustering at resolutions 0.4, 0.8 and 1.2 (stored by
#     FindClusters() in meta.data)
#   * SC3 with an estimated number of clusters, stored as
#     meta.data$SC3_clusters
#
# seu      : Seurat object containing the reduction `dimred`.
# n.cores  : cores passed to sc3().
# dimred   : reduction used by FindNeighbors().
# dims     : dimensions of `dimred` used by FindNeighbors().
# sc3.seed : RNG seed set immediately before sc3().
# Returns the updated Seurat object.
run_benchmark_clusters <- function(seu,n.cores = 4,dimred = "pca",dims = 1:10,sc3.seed = 1234)
{
  ## Seurat workflow
  seu <- FindNeighbors(seu, reduction = dimred, dims = dims)
  seu <- FindClusters(seu, resolution = c(0.4, 0.8, 1.2))

  ## SC3
  library(SC3)
  sce <- as.SingleCellExperiment(seu)
  rowData(sce)$feature_symbol <- rownames(sce)
  counts(sce) <- as.matrix(counts(sce))
  logcounts(sce) <- as.matrix(logcounts(sce))
  set.seed(sc3.seed)
  sce <- sc3(object = sce, k_estimator = TRUE, biology = FALSE, n_cores = n.cores)

  # select exactly the sc3_<k>_clusters column; `[[` with several matches
  # would silently perform recursive indexing instead
  sc3.col <- grep("^sc3_[0-9]+_clusters$", colnames(colData(sce)), value = TRUE)
  if (length(sc3.col) != 1)
  {
    stop("expected exactly one sc3_<k>_clusters column, found ",
         length(sc3.col), call. = FALSE)
  }
  seu@meta.data$SC3_clusters <- colData(sce)[[sc3.col]]

  seu
}


# Extract the first split of an MSC module hierarchy: the child modules of
# the root module, plus a per-cell label vector over those children.
#
# msc.obj        : list with $cell.network (igraph graph with named vertices),
#                  $module.table (columns cluster.name, parent.cluster,
#                  module.size, is.valid) and $modules (named list of member
#                  vertex names).
# use.valid.gate : TRUE keeps modules flagged is.valid; FALSE keeps modules
#                  with module.size >= min.size.
# min.size       : size cut-off used when use.valid.gate is FALSE.
# Returns an unnamed list of two elements:
#   [[1]] named list of the root's child modules;
#   [[2]] vector named by cell-network vertex, holding the child module name
#         of each vertex (NA if in none; later modules win on overlap).
extract_first_split <- function(msc.obj,use.valid.gate = FALSE,min.size = 5)
{
  require(igraph)
  bg <- V(msc.obj$cell.network)$name

  if (use.valid.gate)
  {
    valid.tbl <- subset(msc.obj$module.table, is.valid)
  }else{
    valid.tbl <- subset(msc.obj$module.table, module.size >= min.size)
  }

  # hierarchy graph with edges from each module to its parent
  hobj <- graph_from_data_frame(valid.tbl[, c("cluster.name", "parent.cluster")], directed = TRUE)
  kout <- igraph::degree(hobj, mode = "out")
  kin <- igraph::degree(hobj, mode = "in")
  # root: no outgoing edge and more than one child. If that is not unique,
  # also accept one outgoing edge and keep the largest candidate module.
  root.id <- names(kout)[kout == 0 & kin > 1]
  if (length(root.id) != 1)
  {
    root.id <- names(kout)[kin > 1 & (kout == 0 | kout == 1)]
    root.size <- valid.tbl$module.size[match(root.id, valid.tbl$cluster.name)]
    root.id <- root.id[which.max(root.size)]
  }

  child.id <- subset(valid.tbl, parent.cluster == root.id)$cluster.name
  mods <- msc.obj$modules[child.id]

  vec <- rep(NA, length(bg))
  names(vec) <- bg
  # seq_along: a root without children leaves every cell NA instead of
  # failing on mods[[1]]
  for (i in seq_along(mods)) vec[bg %in% mods[[i]]] <- names(mods)[i]

  list(mods, vec)
}

# Normalized mutual information between two labelings of the same items:
#   NMI = I(c; t) / sqrt(H(c) * H(t))   (natural logarithm)
#
# c, t : equal-length label vectors (character, factor or numeric). Labels
#        are always re-encoded to contiguous codes, so numeric labels need
#        not be 1..k (e.g. 0-based cluster ids). NA is treated as an
#        ordinary label.
# Returns a number in [0, 1]; NaN when either labeling has a single label
# (zero entropy).
NMI <- function(c,t){
  stopifnot(length(c) == length(t))

  ci <- match(c, unique(c))
  ti <- match(t, unique(t))
  n <- length(ci)

  # contingency table: rows = labels of c, columns = labels of t
  N <- unclass(table(ci, ti))
  N_c <- rowSums(N)
  N_t <- colSums(N)

  # empty cells give 0 * log(0) = NaN, which na.rm drops (contribution 0)
  I <- sum((N / n) * log(n * N / outer(N_c, N_t)), na.rm = TRUE)

  # both entropies are computed with the same (negative) sign; their product
  # is therefore positive
  H_c <- sum((N_c / n) * log(N_c / n), na.rm = TRUE)
  H_t <- sum((N_t / n) * log(N_t / n), na.rm = TRUE)

  I / sqrt(H_c * H_t)
}

## adaptation of cluster purity code from Tian et al 2019 (the cellbench paper)
# Shannon entropy (natural log) of the label distribution in `x`.
# Proportions are taken over length(x), so NA entries (which table() does not
# count) still enlarge the denominator. An empty vector gives 0.
cal_entropy <- function(x) {
  p <- as.vector(table(x)) / length(x)
  p <- p[p > 0]
  -sum(p * log(p))
}

# For each gold-standard label, the entropy (cal_entropy) of the cluster
# assignments of its cells; lower means that label's cells fall into fewer
# clusters.
#
# cls : either a named list of clusters (each element a vector of cell names),
#       or a vector of cluster labels parallel to `grp`.
# grp : gold-standard labels; in list mode it must be named by cell so that
#       members can be looked up. NA labels are ignored.
# Returns list(clusterwise.purity = per-label entropies named by label,
#              overall.purity = their mean, NA removed; NaN if no labels).
get_cluster_purity=function(cls,grp){
  vec <- setdiff(unique(grp), NA)
  cls.ent <- rep(NA, length(vec))
  names(cls.ent) <- vec

  if (is.list(cls))
  {
    for (i in seq_along(vec))
    {
      labs <- names(grp)[which(grp == vec[i])]
      # one entry per (cluster, cell of this label) membership
      hits <- vapply(cls, function(members) length(intersect(members, labs)), integer(1))
      cls.ent[i] <- cal_entropy(rep(names(cls), hits))
    }
  }else{
    for (i in seq_along(vec)) cls.ent[i] <- cal_entropy(cls[which(grp == vec[i])])
  }

  list(clusterwise.purity = cls.ent, overall.purity = mean(cls.ent, na.rm = TRUE))
}

# For each cluster, the entropy (cal_entropy) of the gold-standard labels of
# its cells; lower means the cluster is dominated by fewer true labels.
#
# cls : either a named list of clusters (each element a vector of cell names),
#       or a vector of cluster labels parallel to `grp`. In vector mode, NA
#       (unassigned) cells are not treated as a cluster, matching how
#       get_cluster_purity() ignores NA gold labels.
# grp : gold-standard labels; in list mode it must be named by cell.
# Returns list(clusterwise.accuracy = per-cluster entropies,
#              overall.accuracy = their mean, NA removed).
get_cluster_accuracy=function(cls,grp){
  if (is.list(cls))
  {
    grp.ent <- rep(NA, length(cls))
    names(grp.ent) <- names(cls)
    for (i in seq_along(cls))
    {
      grp.ent[i] <- cal_entropy(grp[names(grp) %in% cls[[i]]])
    }
  }else{
    # excluding NA avoids a phantom cluster: which(cls == NA) selects no
    # cells, so it would always score a perfect entropy of 0
    vec <- setdiff(unique(cls), NA)
    grp.ent <- rep(NA, length(vec))
    names(grp.ent) <- vec
    for (i in seq_along(vec)) grp.ent[i] <- cal_entropy(grp[which(cls == vec[i])])
  }

  list(clusterwise.accuracy = grp.ent, overall.accuracy = mean(grp.ent, na.rm = TRUE))
}

# Build two summary figures of a score table.
#
# dat  : object melted with reshape::melt(); the melt result must have three
#        columns, renamed here to data.id, method and mi (the plotted score).
#        Presumably a data-set x method matrix — confirm with callers.
# nm   : display labels for methods (fill legend of plot 1, x axis of plot 2).
# ylab : y-axis title of both plots.
# Returns list(<dodged bar chart per data.id>, <per-method quasirandom points
# over a boxplot>). Points with data.id "Koh*" are omitted from the second
# plot, while its boxplot still uses every row.
make_summary_plots <- function(dat,nm,ylab)
{
  library(ggbeeswarm)
  library(reshape)

  long.df <- melt(dat)
  colnames(long.df) <- c("data.id", "method", "mi")

  bar.plot <- ggplot(data = long.df) +
    geom_bar(aes(x = data.id, y = mi, fill = method),
             stat = "identity", position = "dodge", colour = "black") +
    theme_bw() +
    labs(x = "Data ID", y = ylab) +
    guides(fill = guide_legend(title = "Method", ncol = 4), colour = "none") +
    scale_fill_discrete(labels = nm) +
    theme(axis.text.x = element_text(size = 18, angle = 45, vjust = 1, hjust = 1),
          axis.text.y = element_text(size = 18),
          axis.title.y = element_text(size = 20),
          axis.title.x = element_blank(),
          legend.title = element_text(size = 19),
          legend.text = element_text(size = 15),
          plot.subtitle = element_text(size = 27, face = "bold", hjust = -0.05),
          legend.position = "bottom")

  swarm.pts <- subset(long.df, data.id != "Koh*")
  swarm.plot <- ggplot() +
    geom_quasirandom(data = swarm.pts, aes(x = method, y = mi, colour = data.id)) +
    labs(x = "Methods", y = ylab) +
    guides(colour = guide_legend(title = "Data ID", ncol = 3)) +
    scale_x_discrete(labels = nm) +
    geom_boxplot(data = long.df, aes(x = method, y = mi),
                 width = 0.2, alpha = 0.2, outlier.colour = NA) +
    theme_bw() +
    theme(axis.text.x = element_text(size = 16, angle = 45, vjust = 1, hjust = 1),
          axis.text.y = element_text(size = 15),
          axis.title = element_text(size = 20),
          legend.title = element_text(size = 19),
          legend.text = element_text(size = 15),
          plot.subtitle = element_text(size = 27, face = "bold", hjust = -0.05),
          legend.position = "bottom")

  list(bar.plot, swarm.plot)
}

### network evaluations
# Convert a membership vector into a sparse cells x groups indicator matrix.
#
# membership : vector of group labels, optionally named by cell; NA entries
#              belong to no group (their row stays all zero).
# Returns a Matrix sparse matrix with one row per element of `membership`
# (row names = names(membership)) and one column per distinct non-NA label
# in order of first appearance (column names = the labels). Entry [i, k] is
# 1 when element i has label k. All-NA input yields a matrix with 0 columns.
membership.to.mem <- function(membership)
{
  cls <- setdiff(unique(membership), NA)
  col.idx <- match(membership, cls)
  hit <- which(!is.na(col.idx))

  # build all 1-entries at once instead of assigning column by column
  mem <- Matrix::sparseMatrix(i = hit, j = col.idx[hit], x = rep(1, length(hit)),
                              dims = c(length(membership), length(cls)))
  if (length(cls) > 0) colnames(mem) <- cls
  rownames(mem) <- names(membership)
  mem
}
# Summarise how strongly a graph's edges stay within known groups.
#
# g     : igraph graph; vertex names are matched against names(cls.t).
# cls.t : reference group label per cell, named by cell (see membership.to.mem).
# Returns a data.frame with one row per group:
#   group.size                 number of cells in the group
#   intra.link / inter.link    edges inside the group (diagonal halved, i.e.
#                              each edge counted once assuming a symmetric
#                              adjacency) / edges to other groups
#   intra.inter.ratio          intra / (intra + inter)
#   intra.inter.ratio.per.node mean over the group's nodes of the fraction of
#                              each node's neighbours lying in its own group
#                              (isolated nodes give NaN)
evaluate_true_cluster_connectivity <- function(g,cls.t)
{
  adj <- igraph::as_adjacency_matrix(g)

  # align membership rows to adjacency rows; drop = FALSE keeps a matrix
  # when there is only a single group
  mem.t <- membership.to.mem(cls.t)
  mem.t <- mem.t[match(rownames(adj), rownames(mem.t)), , drop = FALSE]

  # cluster to cluster link counts
  adj.count <- t(mem.t) %*% adj %*% mem.t
  diag(adj.count) <- 1/2 * diag(adj.count)

  # per group ratio
  inter.link <- colSums(adj.count) - diag(adj.count)
  intra.link <- diag(adj.count)

  # per node ratio: neighbours of each node in each group, then in its own group
  adj.per.node <- t(mem.t) %*% adj
  intra.per.node <- colSums(t(mem.t) * adj.per.node)
  inter.per.node <- colSums(adj.per.node) - intra.per.node
  ratio.per.node <- as.vector(t(mem.t) %*% cbind(intra.per.node/(inter.per.node + intra.per.node))/colSums(mem.t))
  names(ratio.per.node) <- colnames(mem.t)

  # put together data
  grp.size <- colSums(mem.t)
  link.df <- data.frame(group.id = names(grp.size), group.size = grp.size,
                        intra.link = intra.link[names(grp.size)],
                        inter.link = inter.link[names(grp.size)],
                        intra.inter.ratio = intra.link[names(grp.size)]/(inter.link[names(grp.size)] + intra.link[names(grp.size)]),
                        intra.inter.ratio.per.node = ratio.per.node[names(grp.size)]
  )
  link.df
}
