#!/usr/bin/python3
from __future__ import division
import os
import subprocess
import sys
import time
import urllib.parse
from os import popen
from os import system

from lxml import etree
import lxml.etree as ET


URL="http://cctop.ttk.hu/direct"
def msg1(line):
	"""Write command-line usage help to standard error.

	``line`` (the argv list) is accepted for call-site compatibility but is
	not used; the program name shown is ``sys.argv[0]``.
	"""
	prog = str(sys.argv[0])
	lines = [
		"Usage: " + prog + " -i fasta_file [-f -s -o output_file]",
		"or",
		prog + " --input fasta_file [--tmfilter --sigpred --output output_file]",
		# -f/-s are plain flags: the argument parser ignores any value after them.
		"Example: " + prog + " -i /home/user/protein.fas -f -s -o /home/user/protein.txt",
		"Output location can be omitted to display result on standard output",
		"-f/--tmfilter and -s/--sigpred are off unless given",
		"After submission, the script will check for results every 10 sec.",
		"Status changes are displayed.",
	]
	sys.stderr.write("\n".join(lines) + "\n")
def getargs(line):
	"""Parse the argv list and check that the input/output paths are usable.

	``line`` is the full argument list, program name included (e.g. sys.argv).
	Recognised options:
	  -i PATH, --input PATH, --input=PATH      input FASTA file
	  -o PATH, --output PATH, --output=PATH    output file ("" = stdout)
	  -f, --tmfilter                           enable the tmfilter flag
	  -s, --sigpred                            enable the sigpred flag
	Any extra word after -f/-s (such as the "1" in "-f 1") is ignored.

	Returns [infile, outfile, tmfilter, sigpred, valid]:
	  - if -i/-o is missing its value, usage is printed and
	    ["", "", False, False, False] is returned;
	  - if the output path cannot be opened for writing, or the input path
	    cannot be opened for reading, a message goes to stderr and valid is
	    False.
	Note: a valid output path is created/truncated by the writability check.
	"""
	infile = ""
	outfile = ""
	tmfilter = False
	sigpred = False

	def option_value(pos):
		# "--input=PATH" / "--output=PATH" carry their value inline; every other
		# form takes the following word. Returns (value, last index consumed);
		# value is None when the following word is missing.
		arg = line[pos]
		if arg.startswith("--") and "=" in arg:
			return arg.split("=", 1)[1], pos
		if pos + 1 < len(line):
			return line[pos + 1], pos + 1
		return None, pos

	# Index 0 is the program name, so scanning starts at 1. Values consumed by
	# -i/-o are skipped so they are never mistaken for options themselves.
	i = 1
	while i < len(line):
		arg = line[i]
		if arg.startswith("-i") or arg.startswith("--input"):
			infile, i = option_value(i)
			if infile is None:
				msg1(line)
				return ["", "", False, False, False]
		elif arg.startswith("-o") or arg.startswith("--output"):
			outfile, i = option_value(i)
			if outfile is None:
				msg1(line)
				return ["", "", False, False, False]
		elif arg.startswith("-f") or arg.startswith("--tmfilter"):
			tmfilter = True
		elif arg.startswith("-s") or arg.startswith("--sigpred"):
			sigpred = True
		i += 1

	if outfile != "":
		try:
			# Opening for writing both validates the path and creates the file.
			with open(outfile, "w"):
				pass
		except IOError:
			sys.stderr.write("Invalid path, check output file directory path and permissions. Use absolute path."+"\n")
			sys.stderr.write(str(outfile)+"\n")
			return [infile, outfile, tmfilter, sigpred, False]
	try:
		with open(infile, "r"):
			pass
	except IOError:
		sys.stderr.write("Invalid path, check input file directory path and permissions. Use absolute path."+"\n")
		sys.stderr.write(str(infile)+"\n")
		return [infile, outfile, tmfilter, sigpred, False]
	return [infile, outfile, tmfilter, sigpred, True]
def read(fastafile):
	"""Parse FASTA records from an open text file object, then close it.

	Returns a dict mapping each header to its sequence:
	  - header: the text after '>' with surrounding whitespace stripped and
	    every '|' replaced by '_';
	  - sequence: the record's lines, each stripped, concatenated, with
	    spaces removed.
	Records with an empty header or an empty sequence are skipped, as are
	lines before the first header. A later record with the same header
	replaces an earlier one. If no record is found, an error message is
	written to stderr and an empty dict is returned.
	"""
	seqs = {}

	def store(header, chunks):
		# Keep a record only when it has both a header and a sequence.
		seq = "".join(chunks).replace(" ", "")
		if header != "" and seq != "":
			seqs[header] = seq

	header = ""
	chunks = []
	for line in fastafile:
		if line.startswith(">"):
			store(header, chunks)
			# Sanitise every header the same way (presumably because the header
			# is later sent to the server as the job id -- confirm).
			header = line[1:].strip().replace("|", "_")
			chunks = []
		else:
			chunks.append(line.strip())
	store(header, chunks)
	fastafile.close()
	if len(seqs) == 0:
		sys.stderr.write("Invalid input file format, please provide a fasta file."+"\n")
	return seqs
def submit(header,proteinsequence,tmfilter,sigpred):
	jobID={}
	fail=[]
	url=URL+"/submit?id="+header+"&seq="+proteinsequence
	if tmfilter:
		url+="&tmfilter"
	if sigpred:
		url+="&sigpred"
	tmp=popen("wget -qO- '"+url+"'")
	content=tmp.readlines()
	ret=content[0]
	sys.stderr.write("Submitted: "+header+" ("+ret+")\n")
	return ret
def poll(header):
	"""Ask the CCTOP server for the current status of a submitted job.

	``header`` is the job id returned by submit(). It is percent-encoded and
	fetched with ``wget -qO-`` without a shell.

	Returns the first line of the reply with surrounding whitespace stripped.
	Callers compare it against status strings such as "Finished", "Running"
	and "Scheduled".

	Raises RuntimeError if the server returns nothing.
	"""
	url = URL + "/poll?hash=" + urllib.parse.quote(header, safe="")
	with subprocess.Popen(["wget", "-qO-", url], stdout=subprocess.PIPE,
			universal_newlines=True) as proc:
		content = proc.stdout.readlines()
	if not content:
		raise RuntimeError("No response from CCTOP server when polling job " + header)
	return content[0].strip()
def state(query):
	"""Report job progress on stderr.

	``query`` maps each header to a [job_id, status] pair; only the status
	(index 1) is read.

	Output:
	  - counts of "Scheduled", "Running" and "Finished" jobs, on one line
	    ending in '\\r' rather than a newline;
	  - counts of "Invalid amino acid", "Invalid" and "Error" statuses, each on
	    its own line and only when non-zero.
	Returns None.
	"""
	statuses = [query[name][1] for name in query]

	def tally(label):
		return str(statuses.count(label))

	sys.stderr.write("Scheduled: " + tally("Scheduled") + " | ")
	sys.stderr.write("Running: " + tally("Running") + " | ")
	sys.stderr.write("Finished: " + tally("Finished") + " jobs\r")
	# Problem statuses, in reporting order: (status, message prefix, suffix).
	problems = (
		("Invalid amino acid", "Invalid amino acid: ", " jobs\n"),
		("Invalid", "The server cannot find the jobid: ", "\n"),
		("Error", "CCTOP encountered an error while running: ", " job\n"),
	)
	for label, prefix, suffix in problems:
		count = statuses.count(label)
		if count > 0:
			sys.stderr.write(prefix + str(count) + suffix)
def res(header):
	"""Download the XML result of a CCTOP job.

	``header`` is the job id returned by submit(). It is percent-encoded and
	fetched with ``wget -qO-`` without a shell.

	Returns the reply as a list of lines with line endings kept, the same
	shape as file.readlines(). An empty list is returned if nothing came
	back. Sleeps one second after each download, presumably to throttle
	requests to the server.
	"""
	url = URL + "/result/" + urllib.parse.quote(header, safe="") + "/xml"
	with subprocess.Popen(["wget", "-qO-", url], stdout=subprocess.PIPE,
			universal_newlines=True) as proc:
		content = proc.stdout.readlines()
	time.sleep(1)
	return content

def main():
	"""Run the CCTOP client: parse sys.argv, submit, poll, and write results.

	Flow:
	  - Sequences of length 4000 or more are not submitted; a message is
	    written to stderr instead.
	  - Submitted jobs are polled every 10 seconds until none reports
	    "Running" or "Scheduled".
	  - Each job's result is downloaded once, as soon as it reports
	    "Finished".
	  - Every result line except the first is printed to stdout (line ending
	    removed) or written verbatim to the output file.
	  - When an output file is given it is always (re)created, even if there
	    are no results.
	"""
	if len(sys.argv) < 3:
		msg1(sys.argv)
		return
	if sys.argv[1].lower() in ("help", "-h", "-help", "--help"):
		msg1(sys.argv)
		return
	[inf, outf, tmfilter, sigpred, valid] = getargs(sys.argv)
	if not valid:
		return
	with open(inf, "r") as fasta:
		# read() closes the file itself; the with-block also covers the case
		# where read() raises.
		sequences = read(fasta)

	max_length = 4000  # sequences at least this long are not submitted
	jobs = {}  # header -> [reply from submit(), last known status]
	for key, sequence in sequences.items():
		if len(sequence) >= max_length:
			sys.stderr.write("Length of sequence is larger than allowed for "+key+"\n")
			continue
		jobid = submit(key, sequence, tmfilter, sigpred)
		# submit() returns this text instead of a job id when the sequence is
		# rejected; such jobs are never polled.
		jobs[key] = [jobid, "Invalid amino acid" if jobid == "Invalid amino acid" else ""]

	result = {}  # header -> list of result lines from res()
	if jobs:
		while True:
			for job in jobs.values():
				if job[1] != "Finished" and job[0] != "Invalid amino acid":
					job[1] = poll(job[0])
			state(jobs)
			for key, job in jobs.items():
				if key not in result and job[1] == "Finished":
					result[key] = res(job[0])
			if not any(job[1] in ("Running", "Scheduled") for job in jobs.values()):
				break
			time.sleep(10)  # poll interval in seconds

	# The first line of every result is skipped in both output modes.
	if outf == "":
		for lines in result.values():
			for text in lines[1:]:
				print(text.replace("\n", ""))
	else:
		with open(outf, "w") as out:
			for lines in result.values():
				out.writelines(lines[1:])


if __name__ == "__main__":
	main()