# Learn practical skills, build real-world projects, and advance your career
import sisl
import numpy as np
import pandas as pd

import os
import plotly.graph_objects as go
# Locate the band edges closest to the Fermi level (energy 0) on either side of it
def calculateGap(bands):
    """Return the band gap around the Fermi level, which is taken to be energy 0.

    The highest energy ``<= 0`` (top of the states below the Fermi level) and the
    lowest energy ``> 0`` (bottom of the states above it) are searched over the
    whole array. Their difference is the (possibly indirect) gap.

    Parameters
    ----------
    bands : array_like
        Band energies, indexed as ``bands[iK, iWF]`` by the callers in this file.
        Presumably already shifted so that the Fermi level is at 0 (TODO confirm
        the energy reference of the input files).

    Returns
    -------
    gap : numpy.float64
        ``E_above - E_below``.
    gapLimitsLoc : numpy.ndarray of int, shape (2, bands.ndim)
        Row 0 is the index of the edge below the Fermi level. Row 1 is the index
        of the edge above it. Ties resolve to the first occurrence in row-major
        order.

    Raises
    ------
    ValueError
        If ``bands`` is empty or all energies lie on one side of the Fermi level.
    """
    bands = np.asarray(bands, dtype=float)

    # Energies exactly at 0 count as "below" (same convention as `bands > 0`).
    isAboveFermi = bands > 0

    if not isAboveFermi.any() or isAboveFermi.all():
        raise ValueError("calculateGap needs energies both above and below the Fermi level (0)")

    # Mask out the other side of the Fermi level before searching. Searching
    # |E| over the whole array could land on an equal-magnitude state of the
    # wrong sign and report a bogus gap.
    belowOnly = np.where(isAboveFermi, -np.inf, bands)
    aboveOnly = np.where(isAboveFermi, bands, np.inf)

    iBelow = np.unravel_index(np.argmax(belowOnly), bands.shape)
    iAbove = np.unravel_index(np.argmin(aboveOnly), bands.shape)

    gapLimitsLoc = np.array([iBelow, iAbove], dtype=int)
    gap = bands[iAbove] - bands[iBelow]

    return gap, gapLimitsLoc

def calculateSpinGaps(bands):
    """Compute the band gap of every spin channel with ``calculateGap``.

    Parameters
    ----------
    bands : iterable of array_like
        One band-energy array per spin channel, each laid out as
        ``calculateGap`` expects (``[iK, iWF]``).

    Returns
    -------
    gaps : list
        ``gaps[s]`` is the gap of spin channel ``s``. There is one entry per
        channel actually present in ``bands``, so no placeholder 0 is returned
        for a missing second channel.
    gapLimitsLocs : list of numpy.ndarray
        ``gapLimitsLocs[s]`` holds the band-edge indices for channel ``s``.
    """
    gaps = []
    gapLimitsLocs = []

    for spinBands in bands:
        gap, gapLimitsLoc = calculateGap(spinBands)
        gaps.append(gap)
        gapLimitsLocs.append(gapLimitsLoc)

    return gaps, gapLimitsLocs

def readData(fileName):
    """Read a band-structure file through sisl.

    Returns ``(ticks, Ks, bands)`` as produced by the sile's ``read_data()``,
    except that axis 1 of ``bands`` is moved to the front. Presumably that is
    the spin axis, so that ``bands[spin]`` is indexed ``[iK, iWF]`` downstream
    (TODO confirm against sisl's bands-sile layout).
    """
    ticks, Ks, rawBands = sisl.get_sile(fileName).read_data()
    # Same permutation as np.rollaxis(rawBands, axis=1); moveaxis is the
    # recommended spelling.
    return ticks, Ks, np.moveaxis(rawBands, 1, 0)
rootDir = "../Remotes/stretchAGNRs/"
resultsDir = "unitCell"
# Quantity that drives the gap change.
# - "Strain": for every .bands file, compute |a - a0| / |a0|, where a is
#   lattice vector `axis` of that calculation and a0 is the same vector in
#   the first (alphabetically sorted) file, assumed to be the relaxed one.
# - Anything else: plot the gaps against the file index.
independentVariable = "Strain"; axis = 0

if not independentVariable:
    # Plain string; a trailing comma here would silently make it a tuple.
    independentVariable = "Index"


def _xvPath(bandsPath):
    """Return the path of the .XV geometry file that sits next to `bandsPath`."""
    # Only the extension is replaced; ".bands" elsewhere in the path is kept.
    return os.path.splitext(bandsPath)[0] + ".XV"


allGapsInfo = []

for simName in sorted(os.listdir(rootDir)):

    simDir = os.path.join(rootDir, simName, resultsDir)

    if not os.path.isdir(simDir):
        continue

    # Use endswith() rather than a substring test, so that files such as
    # "<label>.bands.WFSX" are not mistaken for band-structure files.
    bandsFiles = [
        os.path.join(simDir, fileName)
        for fileName in sorted(os.listdir(simDir))
        if fileName.endswith(".bands")
    ]

    if not bandsFiles:
        # Nothing to analyse, and no reference geometry for the strain.
        continue

    if independentVariable == "Strain":
        relaxedVec = sisl.get_sile(_xvPath(bandsFiles[0])).read_geometry().cell[axis, :]
        independentVals = []
    else:
        # False tells the plotting code there are no explicit x values.
        independentVals = False

    gapEvolution = []; gapLimitsLocsEv = []

    for fileName in bandsFiles:

        # Value of the independent variable (what is modifying the gap).
        if independentVariable == "Strain":
            stretchedVec = sisl.get_sile(_xvPath(fileName)).read_geometry().cell[axis, :]
            independentVals.append(
                np.linalg.norm(stretchedVec - relaxedVec) / np.linalg.norm(relaxedVec)
            )

        # Gap of every spin channel for this calculation.
        ticks, Ks, bands = readData(fileName)
        gaps, gapLimitsLocs = calculateSpinGaps(bands)

        gapEvolution.append(gaps)
        gapLimitsLocsEv.append(gapLimitsLocs)

    allGapsInfo.append([simName, np.array(gapEvolution), gapLimitsLocsEv, independentVals])

# Build the DataFrame from the row lists directly. The cells hold arrays and
# lists of differing lengths, which np.array() cannot pack into a regular 2-D
# array (NumPy >= 1.24 raises ValueError for such ragged input).
df = pd.DataFrame(
    [row[1:] for row in allGapsInfo],
    columns=["Gap evolution", "Gap locations", independentVariable],
    index=[row[0] for row in allGapsInfo],
)
df.head()
# Column labels: per-spin gaps, band-edge locations, independent variable.
gapCol, gapLocsCol, indVar = df.columns
data = []

for rowName, dfRow in df.iterrows():

    # Rows are calculations, columns are spin channels. Only column 0 (the
    # first spin channel) is plotted.
    gaps = np.asarray(dfRow[gapCol])
    if gaps.ndim != 2 or gaps.shape[0] == 0:
        continue  # no calculations found for this simulation

    xVals = dfRow[indVar]
    if isinstance(xVals, (list, tuple, np.ndarray)) and len(xVals) > 0:
        # Sort by the independent variable so the line is drawn monotonically in x.
        order = np.argsort(xVals)
        xy = {"x": np.asarray(xVals)[order], "y": gaps[order, 0]}
    else:
        # No explicit x values: plotly falls back to the calculation index.
        xy = {"y": gaps[:, 0]}

    data.append({**xy, "name": rowName, "mode": "lines+markers", "hovertemplate": 'Band gap: %{y:.2f} eV'})

# Name the x axis after the independent variable. Fall back to "Index" when
# the column label is not a plain string.
xTitle = indVar if isinstance(indVar, str) else "Index"

fig = go.Figure(
    data=[go.Scatter(trace) for trace in data],
    layout=go.Layout({
        "title": "Gap evolution",
        "plot_bgcolor": "white",
        "xaxis": {
            "title": xTitle,
            "showgrid": False,
            "zeroline": False
        },
        "yaxis": {
            "title": "Band gap (eV)",
            "showgrid": False,
            "zeroline": False
        }
    })
)

fig.show()
rootDir = "../Remotes/stretchAGNRs"
# One-off fix-up: inside each simulation's "unitCell" folder, rename the
# geometry file "0<dir>0.0.XV" to "0<dir>.00.XV". Folders without that file
# are left untouched.
for direc in os.listdir(rootDir):
    unitCellDir = os.path.join(rootDir, direc, "unitCell")
    misnamedXV = os.path.join(unitCellDir, f"0{direc}0.0.XV")
    if not os.path.exists(misnamedXV):
        continue
    os.rename(misnamedXV, os.path.join(unitCellDir, f"0{direc}.00.XV"))