import multiprocessing
import os
import subprocess

import MySQLdb
from flask import Blueprint, render_template, request, send_file

# Blueprint that carries every pcap-library route defined in this module.
library_api = Blueprint('library_api', __name__)

# Combinator for protocol searches in pcapproto(): 'AND' intersects the
# per-protocol result lists, any other value ORs them.  Written by pcaplogic().
# NOTE(review): module-level state, so the setting is shared by every client
# served by this process rather than being per user/session.
logic = 'OR'

@library_api.route("/deletepcap/", methods=['GET', 'POST'])
def deletepcap():
    """Delete a pcap's database rows and its files under /data/pcaps.

    Expects POST form field ``filename``.  Rows with that filename are removed
    from the ``pcapfiles`` and ``protocols`` tables in one transaction.  Then
    every entry in /data/pcaps whose name starts with ``filename`` is removed.
    This is the same match ``rm /data/pcaps/<filename>*`` made, but done
    without a shell.

    Returns the filename on success.  Non-POST requests get "POST ONLY", and
    an empty or path-like filename gets a 400.
    """
    if request.method != 'POST':
        return "POST ONLY"

    filename = str(request.form['filename'])

    # The name is used as a prefix match: "" would match (and delete) every
    # file in the directory, and a name containing "/" can never match a
    # directory entry anyway.
    if not filename or os.sep in filename:
        return "invalid filename", 400

    # NOTE(review): every other handler in this file connects to db
    # "veronica" with a different password; confirm "protocols"/"betty" is
    # intentional here.
    db = MySQLdb.connect(host="localhost", passwd="betty", db="protocols")
    try:
        cr = db.cursor()
        # Parameterised queries: the filename comes straight from the client.
        cr.execute("delete from pcapfiles where filename=%s;", (filename,))
        cr.execute("delete from protocols where filename=%s;", (filename,))
        db.commit()
    finally:
        db.close()

    # NOTE(review): uploads/downloads use /pcaps while deletes use
    # /data/pcaps; presumably the same directory -- confirm.
    pcap_dir = '/data/pcaps'
    try:
        names = os.listdir(pcap_dir)
    except OSError as err:
        print("delete: cannot list %s: %s" % (pcap_dir, err))
        names = []

    for name in names:
        if name.startswith(filename):
            path = os.path.join(pcap_dir, name)
            print("delete %s" % path)
            try:
                os.remove(path)
            except OSError as err:
                # Best effort, like the old `rm`: report and keep going.
                print("delete failed %s: %s" % (path, err))

    return filename

#
# Download a PCAP from the repository to the local browser
#

@library_api.route("/downloadpcap/", methods=['GET', 'POST'])
def downloadpcap():
    """Send /pcaps/<filename> to the browser as a file attachment.

    Expects POST form field ``filename``.  Non-POST requests get "POST ONLY";
    before this change they crashed on an unbound ``src``.  Names that are
    empty, "." / "..", or contain a path separator are rejected with 400, so
    a request cannot read files outside /pcaps.
    """
    if request.method != 'POST':
        return "POST ONLY"

    filename = request.form['filename']
    if (not filename or filename in ('.', '..')
            or os.path.basename(filename) != filename):
        return "invalid filename", 400

    src = os.path.join('/pcaps', filename)
    return send_file(src, attachment_filename=filename, as_attachment=True)

#
# Upload a pcap file to Richmond, process it to obtain a text file, then index it into the database for searches
#

@library_api.route("/uploadpcap/", methods=['GET', 'POST', 'PUT'])
def uploadpcap():
    """Store an uploaded pcap, rebuild the text/DB index, and replicate it.

    On POST, the form file ``file`` is saved as /pcaps/<secure name>, unless
    a file with that name already exists.  Two external steps then run:
    /data/pcaps/update-pcaptext.sh and databaserefresh.py.  Finally, a
    background copy to each replay host is started with startCopy().

    Renders fileuploaded.html on success and fileerror.html if the name is
    unusable or the save fails.  Any other method returns "POST ONLY".
    """
    print("upload library %s" % request.method)

    if request.method != 'POST':
        return "POST ONLY"

    # NOTE(review): secure_filename is not imported anywhere in this file
    # (presumably werkzeug.utils.secure_filename) -- add the import.
    f = request.files['file']
    safe_name = secure_filename(f.filename or '')
    if not safe_name:
        # secure_filename() can reduce a hostile or blank name to "".
        return render_template('fileerror.html')
    path = '/pcaps/' + safe_name

    # Only save if it doesn't exist.  An existing file is still re-indexed
    # and re-copied below.
    if not os.path.isfile(path):
        try:
            f.save(path)
        except (IOError, OSError):
            return render_template('fileerror.html')

    # Process the pcap to pcap.txt, then update the database.
    subprocess.call(['/data/pcaps/update-pcaptext.sh'])
    subprocess.call(['python', '/var/www/FlaskApps/betty/databaserefresh.py'])

    # FIXME: Use database entries to copy to all
    replay_targets = (
        ('suffolk', 'npadmin'),
        ('portsmouth', 'npadmin'),
        ('roanoke', 'npadmin'),
        ('harding', 'npadmin'),
        ('bannon', 'devopsuser'),
        ('hayes-01', 'npadmin'),
        ('potel', 'jenkins'),
        ('wilson', 'jenkins'),
    )
    for server, user in replay_targets:
        startCopy(path, server, user)

    return render_template('fileuploaded.html')

def startCopy(path, server, user):
    """Start copyToReplay(path, server, user) in a separate process.

    Returns as soon as the process is started; it does not wait for the
    copy to finish.
    """
    worker = multiprocessing.Process(
        target=copyToReplay,
        args=(path, server, user),
    )
    worker.start()

def copyToReplay(localpath, server, user):
    """Copy ``localpath`` to the same path on ``server`` over SFTP.

    Connects as ``user`` with a password, uploads the file, and always
    closes the SFTP session and SSH connection, including when the connect
    or upload raises.
    """
    remotepath = localpath
    print("copy to remote producer %s %s" % (localpath, server))

    # NOTE(review): paramiko is not imported anywhere in this file -- add
    # `import paramiko`.  AutoAddPolicy trusts any host key (MITM risk) and
    # the password is hard-coded; both belong in configuration.
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(server, username=user, password='hammerhead')
        sftp = ssh.open_sftp()
        try:
            sftp.put(localpath, remotepath)
        finally:
            sftp.close()
    finally:
        ssh.close()

    print("done copy %s" % server)

#
# Main library screen for upload/download/convert/index
#

@library_api.route("/loadlibrary/", methods=['GET','POST'])
def loadlibrary():
    """Render the main library page (upload/download/convert/index).

    Passes these values to loadlibrary.html:
    - the sorted, de-duplicated pcap names (column 1 of ``pcapfiles``);
    - the protocol names (column 1 of ``protocols``), sorted case-insensitively;
    - a "Total of N pcap files found" message.

    If /pcaps cannot be listed, the template is rendered with empty values
    instead.  If the protocol query fails, the page still renders without
    protocols.
    """
    data = ""
    protos = ""

    # The listing itself is unused; it only checks that /pcaps is readable.
    try:
        os.listdir('/pcaps')
    except OSError:
        return render_template('loadlibrary.html', data=data, value="", protos=protos, dp="")

    db = MySQLdb.connect(host="localhost", passwd="veronicaHammerhead27!", db="veronica")
    try:
        cr = db.cursor()

        # Get list of all pcaps that are in the database.
        cr.execute("SELECT * FROM pcapfiles;")
        results = cr.fetchall()
        print(results)
        data = sorted(set(r[1] for r in results))

        # Load list of protocols (best effort).  A lambda key works for both
        # str and unicode rows, unlike str.lower under Python 2.
        try:
            cr.execute("SELECT * FROM protocols;")
            protos = sorted(set(p[1] for p in cr.fetchall()),
                            key=lambda name: name.lower())
        except Exception:
            print("Error: unable to fetch data")
    finally:
        db.close()

    numfiles = " Total of %d pcap files found" % len(data)

    return render_template('loadlibrary.html', data=data, protos=protos, numfiles=numfiles)

#
# Library load window for the producers, mostly identical code but no upload/download from this screen
#

@library_api.route("/loadtraffic/", methods=['GET','POST'])
def loadtraffic():
    """Render the producer-side library window (no upload/download here).

    Expects POST form fields ``servername``, ``channel`` and ``location``.
    These are passed through to loadtraffic.html together with:
    - the sorted, de-duplicated ``pcapfiles`` names ending in "pcap";
    - the protocol names, sorted case-insensitively;
    - a "Total of N pcap files found" message.

    If /pcaps cannot be listed, the template is rendered with empty values.
    """
    # Defined up front so both the fallback render and a failed protocol
    # query have values, rather than raising NameError.
    data = ""
    protos = ""

    # The listing itself is unused; it only checks that /pcaps is readable.
    try:
        os.listdir('/pcaps')
    except OSError:
        return render_template('loadtraffic.html', data=data, value="", protos=protos, dp="")

    sname = request.form['servername']
    channel = request.form['channel']
    location = request.form['location']

    db = MySQLdb.connect(host="localhost", passwd="veronicaHammerhead27!", db="veronica")
    try:
        cr = db.cursor()

        # Get list of all pcap file names, de-duplicated and sorted.
        cr.execute("SELECT * FROM pcapfiles;")
        data = sorted(set(r[1] for r in cr.fetchall() if r[1].endswith('pcap')))

        # Load list of protocols (best effort).
        try:
            cr.execute("SELECT * FROM protocols;")
            protos = sorted(set(p[1] for p in cr.fetchall()),
                            key=lambda name: name.lower())
        except Exception:
            print("Error: unable to fetch data")
    finally:
        db.close()

    debug = " Total of %d pcap files found" % len(data)

    return render_template('loadtraffic.html', data=data, protos=protos, channel=channel, debug=debug, sname=sname, location=location)

#
# Click on file in window causes this to be called. Display the file info based on the file name
#

@library_api.route('/pcaptext/', methods=['GET', 'POST'])
def pcaptext():
    """Return an HTML fragment describing one pcap file.

    Reads form field ``filename`` and builds two tables:
    - the matching ``pcapfiles`` rows, with columns 2 and 3 shown as
      "label: value";
    - the matching ``protocols`` rows, with columns 1-4 shown under the
      Protocols / packets / bytes / flows headings.

    A section whose query or formatting fails is replaced by "bad fetch";
    the other section is still returned.
    """
    try:
        from html import escape  # Python 3
    except ImportError:
        from cgi import escape  # Python 2

    def cell(value):
        # Text-format any column value (numbers included) and HTML-escape it:
        # both the filename and the DB contents are echoed into the page.
        return escape("%s" % (value,))

    info_row = "<tr><td align=left>%s:</td><td align=right>%s</td></tr>"
    proto_row = ("<tr><td align=left>%s</td><td align=right>%s</td>"
                 "<td align=right>%s</td><td align=right>%s</td></tr>")
    proto_header = ("<table width=100%><tr><td align=left><h4>Protocols</h4></td>"
                    "<td align=right><b>packets</b></td>"
                    "<td align=right><b>bytes</b></td>"
                    "<td align=right><b>flows</b></td></tr>")

    fn = request.form['filename']

    db = MySQLdb.connect(host="localhost", passwd="veronicaHammerhead27!", db="veronica")
    try:
        cr = db.cursor()

        # File summary.  Parameterised: the filename comes from the client.
        try:
            cr.execute("select * from pcapfiles where filename=%s", (fn,))
            rows = [info_row % (cell(r[2]), cell(r[3])) for r in cr.fetchall()]
            info = ("<h4>" + cell(fn) + "</h4>"
                    + "<style>tr:nth-child(even) {background-color: #f2f2f2}</style>"
                    + "<table width=100%>" + "".join(rows) + "</table>")
        except Exception:
            info = "bad fetch"

        # Protocols for this pcap.
        try:
            cr.execute("select * from protocols where filename=%s", (fn,))
            rows = [proto_row % tuple(cell(v) for v in r[1:5]) for r in cr.fetchall()]
            protos = proto_header + "".join(rows) + "</table>"
        except Exception:
            protos = "bad fetch"
    finally:
        db.close()

    return info + protos


#
# Set the protocol-search logic ('OR' or 'AND') from the radio-button click and store it in the module-level `logic`.
#

@library_api.route('/pcaplogic/', methods=['GET', 'POST'])
def pcaplogic():
    """Store the submitted ``radio`` form value in the module-level ``logic``.

    Prints the whole form for debugging and echoes the stored value back.
    """
    global logic
    form = request.form
    print(form)
    logic = form['radio']
    return logic

#
# Return <option> entries built from protocols-table rows matching the selected protocol(s)
#

@library_api.route('/pcapproto/', methods=['GET', 'POST'])
def pcapproto():
    """Return <option> elements for rows matching the selected protocols.

    Form field ``qstring`` is a comma-separated list of protocol names.  The
    option values come from column 5 of the ``protocols`` table:
    - when the module-level ``logic`` is 'AND', only values present for
      every protocol are returned;
    - otherwise, values present for any protocol are returned.

    Options are de-duplicated, sorted and HTML-escaped.
    """
    try:
        from html import escape  # Python 3
    except ImportError:
        from cgi import escape  # Python 2

    queries = request.form['qstring'].split(',')

    db = MySQLdb.connect(host="localhost", passwd="veronicaHammerhead27!", db="veronica")
    try:
        cr = db.cursor()
        if logic != 'AND':
            # One parameterised IN (...) query; protocol names come from the
            # client, so they are never spliced into the SQL text.
            sql = ("select * from protocols where protocol IN (%s)"
                   % ", ".join(["%s"] * len(queries)))
            try:
                cr.execute(sql, tuple(queries))
                matches = set(row[5] for row in cr.fetchall())
            except Exception:
                matches = set()
        else:
            # AND option: intersect the per-protocol result sets.  As before,
            # a protocol whose query fails is skipped instead of failing all.
            matches = None
            for q in queries:
                try:
                    cr.execute("select * from protocols where protocol=%s", (q,))
                    files = set(row[5] for row in cr.fetchall())
                except Exception:
                    continue
                matches = files if matches is None else matches & files
            if matches is None:
                # Every query failed: nothing to offer.
                matches = set()
    finally:
        db.close()

    options = []
    for f in sorted(matches):
        text = "%s" % (f,)
        options.append('<option value="%s">%s</option>'
                       % (escape(text, True), escape(text)))
    return "".join(options)

