#!/usr/bin/env python

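# Task to transfer AIPS calibration (SN) table solutions to another frequency
# Written Jan2003 Enno Middelberg
# Task reads an SN table that has to have phases/rates/delays in LCP IFs only,
# interpolates between adjacent solutions and scales the phases by the ratio
# of the two frequencies given on the command line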

import string, sys, fileinput, re, math, tty, os
# import Gnuplot

if len(sys.argv) < 7:
    print("\n ffstg_intfreqs written by Enno Middelberg Jan 2003")
    print("\n Task reads an SN table that has to have phases/rates/delays")
    print(" in LCP IFs only and scales them by a factor specified on command line")
    print(" On the command line, specify:")
    print(" filename_lowfreq antenna_number lowfreq highfreq no_IFs highfreq_freqID [-noplot]\n")
    # Wrong usage is an error: report it with a non-zero exit status.
    sys.exit(1)

# regex: string starts with at least one digit (used with .match)
matchstr = re.compile(r'\d+')
# the input file (the low-frequency SN table as text)
file1 = sys.argv[1]
# which antenna to deal with (compared as a string with the table's antenna column)
antenna = sys.argv[2]
antint = int(antenna)
# low frequency
freq1 = float(sys.argv[3])
# the target frequency
freq2 = float(sys.argv[4])
# how many IFs are there?
noifs = int(sys.argv[5])
# Both frequencies are divided by / scaled with later on, and at least one IF
# is needed to read anything: reject nonsense early instead of crashing later.
if freq1 <= 0 or freq2 <= 0 or noifs < 1:
    sys.exit("lowfreq and highfreq must be positive and no_IFs must be at least 1")
# compute the scaling factor
freqratio = freq2 / freq1
# compute the maximum rate allowed
# it's 1 turn / min = 16.6 mHz at freq2 (freq2 is multiplied by 1E9, i.e. in GHz)
maxrate = 0.0166 / (freq2 * 1E9)
# high_freq freqid
freqid2 = sys.argv[6]

print("\nLow frequency (freq1):  " + repr(freq1))
print("High frequency (freq2): " + repr(freq2) + "\t(Frequency ID: " + sys.argv[6] + ")")
print("Frequency Ratio:          " + repr(freqratio))
print("Antenna number:           " + antenna)
print("Number of IFs:            " + repr(noifs))
print("Maximum phase rate:       " + repr(maxrate))

##################################################
## My definitions, my definitions are these :-) ##
##################################################


# Delete (and print) an entry of the module-level list "output"
def delentry(freqid, timestamp):
    """Print and delete the first entry of the global `output` whose time
    (element [0]) is not earlier than `timestamp` and whose element [1]
    equals `freqid`.

    Raises IndexError if no entry matches (the scan runs off the end).
    """
    # NOTE(review): when the SN table is parsed, element [1] of each "output"
    # entry is filled with an "h:m:s" time string, not a frequency ID, so this
    # comparison may never succeed. delentry is not called in the code shown
    # here -- confirm the intended field before using it.
    x=0
    while output[x][0]<timestamp or output[x][1]!=freqid:
	x=x+1
    print output[x]
    del output[x]
    
# create 2D-lists suitable for plotting and fitting
def getdata(phase_or_rate, identifier, i, starttime, stoptime):
    """Return a list of [time, value] pairs for IF number `i`.

    phase_or_rate: "phase" gives the phase in degrees (via pol2deg);
                   anything else gives the rate multiplied by 1e12.
    identifier:    "output" (solutions read from the input table) or
                   "newoutput" (interpolated, scaled solutions) selects the
                   module-level list to read from.
    starttime, stoptime: inclusive time window in seconds. If stoptime is
                   smaller than starttime, the whole time span of the
                   selected list is used.

    Entries whose value cannot be computed (the 'INDE' placeholder, or a
    zero-amplitude gain with undefined phase) are skipped.
    Raises ValueError for an unknown identifier.
    """
    if identifier == "output":
        source = output
    elif identifier == "newoutput":
        source = newoutput
    else:
        raise ValueError("identifier must be 'output' or 'newoutput', not %r" % (identifier,))
    if not source:
        return []
    # Check the timerange handed over; an inverted range means "everything".
    if stoptime < starttime:
        times = sorted(entry[0] for entry in source)
        starttime = times[0]
        stoptime = times[-1]
    sys.stdout.write("Starttime, Stoptime:  %s %s\n" % (starttime, stoptime))
    points = []
    for entry in source:
        # Filter on the timestamps of the list actually being returned.
        timestamp = entry[0]
        if not starttime <= timestamp <= stoptime:
            continue
        gain = entry[2][i]
        try:
            if phase_or_rate == "phase":
                value = pol2deg(gain[0], gain[1])
            else:
                value = 1e12 * gain[3]
        except (TypeError, ValueError, ZeroDivisionError):
            continue
        points.append([timestamp, value])
    return points


# Block until a single key is pressed and hand it back
def getch():
    """Read one byte from standard input without waiting for Return.

    The terminal is put into cbreak mode only for the duration of the read;
    its previous attributes are restored afterwards (discarding pending
    input), also when the read fails.
    """
    stdin_fd = sys.stdin.fileno()
    saved_attrs = tty.tcgetattr(stdin_fd)
    tty.setcbreak(stdin_fd)
    try:
        return os.read(stdin_fd, 1)
    finally:
        tty.tcsetattr(stdin_fd, tty.TCSAFLUSH, saved_attrs)

# Convert (real, imag) into phase
def pol2deg(x, y):
    """Return the phase of the complex value x + i*y in degrees (0 to 360).

    atan2 is used instead of acos(x / |z|), so large components cannot
    overflow when squared and tiny ones cannot underflow to a zero radius.
    A zero vector has no defined phase and raises ValueError (this used to
    be an uncaught ZeroDivisionError). Non-numeric input such as the 'INDE'
    placeholder raises TypeError, as before.
    """
    if x == 0 and y == 0:
        raise ValueError("phase of a zero-amplitude value is undefined")
    phi = math.degrees(math.atan2(y, x))
    # atan2 returns (-180, 180]; map negative angles into the upper range
    if phi < 0.:
        phi = phi + 360
    return phi

# Turn a phase in degrees into the (real, imag) parts of a unit phasor
def deg2pol(phase):
    """Return (cos, sin) of `phase` degrees as a (real, imag) tuple."""
    radians = math.pi * phase / 180
    return math.cos(radians), math.sin(radians)

# Return sign of a number
def sign(x):
    """Return -1, 0 or 1 according to the sign of x (-0.0 gives 0).

    Raises ValueError for values that are neither below, above nor equal
    to zero (NaN); previously none of the branches ran and the function
    failed with an UnboundLocalError.
    """
    if x < 0:
        return -1
    if x > 0:
        return 1
    if x == 0:
        return 0
    raise ValueError("sign() of unordered value %r" % (x,))

# Scale the phase of a (real, imag) pair by a factor
def turnphase(re, im, int):
    """Return the unit phasor (real, imag) whose phase is `int` times the
    phase (in degrees) of re + i*im.

    The parameter names shadow the `re` module and the `int` builtin inside
    this function; they are kept so keyword callers keep working.
    """
    return deg2pol(int * pol2deg(re, im))
    

# Convert seconds into hms format
def time2hms(seconds):
    """Format a time in seconds as a zero-padded "HH:MM:SS.ss" string.

    Minutes are computed with floor division, so the result is the same
    under Python 2 and Python 3 (with "/" Python 3 yields fractional minutes
    and therefore wrong seconds). Intended for non-negative times.
    """
    h = int(seconds / 3600)
    m = int(seconds % 3600) // 60
    s = seconds - h * 3600 - m * 60
    return "%02d:%02d:%05.2f" % (h, m, s)

# Return the index just before the first entry of `list` whose time
# (element [0]) is not earlier than `timestamp`
def next(list, timestamp):
    """Scan `list` for the first entry with list[x][0] >= timestamp and
    return x - 1, i.e. -1 if the very first entry already qualifies.

    Prints each skipped (timestamp, entry time) pair (debugging output).
    """
    # NOTE(review): nothing is actually checked for existence -- if every
    # entry is earlier than timestamp the loop runs past the end and raises
    # IndexError. The name and parameter shadow the builtins next() and
    # list(). Not called in the code shown here.
    x=0
    while list[x][0]<timestamp:
	print timestamp, list[x][0]
	x=x+1
    return x-1

# Meant to find a previous item; see the review note below
def previous(list, timestamp):
    """Return one more than the index of the first element of `list` that
    does not compare less than `timestamp`.
    """
    # NOTE(review): unlike next(), this compares the whole element list[x]
    # instead of its time list[x][0]. For [time, ...] entries that is not a
    # time comparison (Python 2 compares across types, Python 3 raises
    # TypeError), and the "x+1" return value does not obviously point at a
    # previous item. Not called in the code shown here -- confirm the intent
    # before using it.
    x=0
    while list[x]<timestamp:
	x=x+1
    return x+1


# Return an index close to a given time in sec
# Used for finding a cursor position in plotting
# when a timerange is specified

def point_pos(i):
    """Return the index of the first entry of the module-level freq1list
    (the [time, value] pairs being plotted) whose time is >= i.

    If every entry is earlier than i, the last index is returned instead of
    running off the end of the list; an empty list gives 0.
    """
    for x, entry in enumerate(freq1list):
        if entry[0] >= i:
            return x
    return len(freq1list) - 1 if freq1list else 0
    

#################################################
##        End of definitions section           ##
#################################################

################################################################################################
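# "output" and "newoutput" have the same layout:
# [timestamp in seconds, second field, [list of IFs], source ID]
# where the second field is an "h:m:s" time string in "output" and the
# high frequency (freq2) in "newoutput".
#
# Each IF entry has real, imaginary, delay and rate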
################################################################################################

# Parsed solutions of the input SN table are collected in this list
output = []

# Slurp the whole input table into memory, one string per line
sys.stdout.write("Reading data from " + file1 + "... ")
lines1 = [line for line in fileinput.input(file1)]
print("done.\n")

# Process input SN table: for every table row of the selected antenna collect
# the solutions of all noifs IFs (the first IF on the row's own line, the
# remaining IFs on the following noifs-1 lines) and append
# [time in seconds, "h:m:s", [[re, im, delay, rate] per IF], source] to output.
# Values that cannot be read stay as the string 'INDE'.
x = 0
lines = lines1
while x < len(lines):
    fields = lines[x].split()
    # Need at least 5 columns before looking at the antenna column fields[4]
    # (the old "> 3" test crashed on lines with exactly 4 tokens).
    if len(fields) > 4 and fields[4] == antenna:
        # Table rows start with two numbers: row number and time (in days)
        if matchstr.match(fields[0]) and matchstr.match(fields[1]):
            time = 86400 * float(fields[1])
            h = int(time / 3600)
            m = int(time % 3600) // 60
            s = time - (h * 3600) - (m * 60)
            src = fields[3]
            # get all IFs
            data = [['INDE', 'INDE', 'INDE', 'INDE'] for _ in range(noifs)]
            for i in range(noifs):
                try:
                    row = lines[x + i].split()
                    data[i][0] = float(row[10])
                    data[i][1] = float(row[11])
                    data[i][2] = float(row[12])
                    data[i][3] = float(row[13])
                except (ValueError, IndexError):
                    # 'INDE' (undefined) values, or a truncated/missing row
                    # at the end of the file: keep the placeholders.
                    pass
            output.append([time, "%d:%d:%4.2f" % (h, m, s), data, src])
            # skip the continuation lines that belong to this row
            x = x + (noifs - 1)
    x = x + 1

# Interpolate timestamps, phases and rates between adjacent solutions, scale
# the phases up to freq2 and store one entry per interpolated pair in
# newoutput: [time in seconds, freq2, [[re, im, delay, rate] per IF], source].
sys.stdout.write("\nModifying phases... ")

newoutput = []
n_of_errors = 0
for x in range(len(output) - 1):
    this_sol = output[x]
    next_sol = output[x + 1]
    delta_t = next_sol[0] - this_sol[0]
    # Only interpolate if adjacent scans are not too far away in time and
    # belong to the same source. Identical timestamps are skipped as well:
    # they give no time baseline for a rate (division by zero).
    if delta_t == 0 or delta_t >= 150 or this_sol[3] != next_sol[3]:
        continue
    newtime = this_sol[0] + delta_t / 2
    src = this_sol[3]
    ifs = []
    try:
        for y in range(noifs):
            re0, im0, delay0, rate0 = this_sol[2][y]
            re1, im1, delay1, rate1 = next_sol[2][y]
            # Signed phase change in degrees from this to the next solution.
            # atan2(cross, dot) works for gains of any amplitude and cannot be
            # pushed out of acos' domain by rounding; its sign is the sign of
            # the determinant (cross product) of the two phasors.
            # 'INDE' (undefined) entries raise TypeError here.
            dphi = math.degrees(math.atan2(re0 * im1 - re1 * im0,
                                           re0 * re1 + im0 * im1))
            # deg/s -> turns/s -> sec/sec at freq1 (freq1 is multiplied by
            # 1E9, i.e. given in GHz); stored in sec/sec, so no freq2 scaling.
            rate = dphi / delta_t / (360 * freq1 * 1E9)
            # add half of the change to the first phase, then scale by freqratio
            mid_re, mid_im = deg2pol(pol2deg(re0, im0) + 0.5 * dphi)
            phase = deg2pol(freqratio * pol2deg(mid_re, mid_im))
            # interpolate delay
            delay = (delay1 + delay0) / 2
            # write some output for the first IF
            if y == 0:
                print("   %s %s %s %s" % (time2hms(this_sol[0]), pol2deg(re0, im0), delay0, rate0))
                print("   %s %s %s %s" % (time2hms(next_sol[0]), pol2deg(re1, im1), delay1, rate1))
                print("-> %s %s %s %s %s" % (time2hms(newtime), pol2deg(mid_re, mid_im), delay, rate, src))
                print("-> %s %s %s %s \n" % (time2hms(newtime), pol2deg(phase[0], phase[1]), delay, rate))
            ifs.append([phase[0], phase[1], delay, rate])
        newoutput.append([newtime, freq2, ifs, src])
    except (TypeError, ValueError, ZeroDivisionError):
        # catch + count solutions that cannot be interpolated ('INDE' values,
        # zero-amplitude gains without a defined phase)
        n_of_errors = n_of_errors + 1

print("done. Number of errors: %i\n" % n_of_errors)
    

if "-noplot" not in sys.argv:
    # The Gnuplot import at the top of the file is commented out, which made
    # this section stop with a NameError. Import it here, where it is needed,
    # and skip the plots if Gnuplot.py is not installed.
    try:
        import Gnuplot
    except ImportError:
        Gnuplot = None
        print("Gnuplot module not available - skipping plots (use -noplot to suppress this message)")

    if Gnuplot is not None:
        def _browse_plots(quantity, ylabel_fmt):
            """Plot old and predicted `quantity` for every IF and let the user browse.

            quantity:   "phase" or "rate", passed on to getdata().
            ylabel_fmt: y-axis label with one %d placeholder for the IF number.
            Keys: 4/6 move the cursor, 5 sets start then stop of the zoom
            window, 9 resets the zoom, 0 asks for a time range, Return moves
            on to the next IF.
            """
            # point_pos() reads the module-level freq1list, so keep it global.
            global freq1list, freq2list
            g = Gnuplot.Gnuplot()
            g('set style data linespoints')
            print(" 4 - move cursor to the left")
            print(" 5 - set windows to zoom into the plot (press once for start and stop)")
            print(" 6 - move cursor to the right")
            print(" 9 - reset zoom window")
            print(" 0 - enter time range")
            print(" Return - continue with the next IF")
            for i in range(noifs):
                freq1list = getdata(quantity, "output", i, 0, -1)
                freq2list = getdata(quantity, "newoutput", i, 0, -1)
                if not freq1list:
                    print("No valid %s data for If #%d, skipping." % (quantity, i))
                    continue
                title1 = 'If #%d' % i
                title2 = 'Predicted If #%d' % i
                g.xlabel('Time in seconds')
                g.ylabel(ylabel_fmt % i)
                x = 0
                starttime_set = 0
                # start every IF with its full time range, not the previous zoom
                starttime = freq1list[0][0]
                stoptime = freq1list[-1][0]
                while True:
                    # refresh the plot
                    point = freq1list[x]
                    g('set xrange [%r:%r]' % (starttime, stoptime))
                    # "with" is a reserved word since Python 2.6; Gnuplot.py
                    # takes the style as with_ instead.
                    items = [Gnuplot.Data(freq1list, title=title1, with_='lines')]
                    if freq2list:
                        items.append(Gnuplot.Data(freq2list, title=title2, with_='lines'))
                    # Gnuplot.Data reads a flat list as one value per point, so
                    # the cursor [time, value] is wrapped into a one-point list.
                    items.append(Gnuplot.Data([point], with_='points pointtype 2 pointsize 3'))
                    g.plot(*items)
                    print("current cursor position:  %s %s %s" % (time2hms(point[0]), point[0], point[1]))
                    key = getch()
                    # Return (or end of input) finishes this IF
                    if not key or ord(key) == 10:
                        break
                    if key == '4':
                        x = (x - 1) % len(freq1list)
                    elif key == '6':
                        x = (x + 1) % len(freq1list)
                    elif key == '5':
                        if starttime_set == 1:
                            stoptime = freq1list[x][0]
                            starttime_set = 0
                        else:
                            starttime = freq1list[x][0]
                            starttime_set = 1
                        print("starttime_set:  %s" % starttime_set)
                    elif key == '0':
                        answer = raw_input("Enter time range to plot (separated with blanks):").split()
                        try:
                            new_start = float(answer[0])
                            new_stop = float(answer[1])
                        except (IndexError, ValueError):
                            print("Please enter two times in seconds, separated by blanks.")
                        else:
                            starttime, stoptime = new_start, new_stop
                            try:
                                x = point_pos((starttime + stoptime) / 2)
                            except IndexError:
                                # middle of the range is after the last point
                                x = len(freq1list) - 1
                    elif key == '9':
                        starttime = freq1list[0][0]
                        stoptime = freq1list[-1][0]
                        starttime_set = 0

        # plot phases
        _browse_plots("phase", "Phase in If #%d")
        # plot rates
        _browse_plots("rate", "Rate of If #%d in 1e12 sec/sec")

# Data editing is done; write solutions to an SN table suitable for AIPS.
# Read the table already written for this freqid, or create it from the
# header of the input table if it does not exist yet. Afterwards newlines2
# holds the table lines (the last one being the end mark) and the
# module-level name `max` holds the row number for the first new entry; the
# writing code below uses that name, which is why the builtin is shadowed.
outname = 'freqid_' + sys.argv[6] + '_output'
newlines2 = []
if os.path.exists(outname):
    # Read errors propagate instead of being mistaken for a missing file,
    # which used to overwrite the existing table with an empty one.
    with open(outname) as infile:
        newlines2 = infile.readlines()
    print("Reading data from " + outname + "...")
    # The last line is the end mark; the one before it is the last table row,
    # whose first column is its row number.
    try:
        max = int(newlines2[-2].split()[0]) + 1
    except (IndexError, ValueError):
        # no table rows yet (only header, begin and end marks): start at row 1
        max = 1
else:
    # copy the header of the input SN table and modify it
    print(outname + " did not exist. Created it.")
    for line in lines1:
        if line.startswith("***END*PASS***"):
            break
        if line.startswith("NAXIS2"):
            # Replace (rather than prepend to) the row count: set it to a
            # ridiculous number, TBIN doesn't care.
            newlines2.append("NAXIS2  =               100000 / Number of entries in table\n")
        else:
            newlines2.append(line)
        # table rows follow the begin mark; only the header is copied
        if line.startswith("***BEGIN*PASS***"):
            break
    # always finish with an end mark, even if the input table lacked one
    newlines2.append("***END*PASS***")
    with open(outname, 'w') as outfile:
        outfile.writelines(newlines2)
    max = 1

# Write output: append the new solutions to the table held in newlines2 and
# replace the file on disk with the result.
outname = 'freqid_' + sys.argv[6] + '_output'
print("Writing data to " + outname + "...")

# Drop the end mark (and any blank lines after it); a new one is added below.
# Only the end mark is removed, never a real table line.
while newlines2 and not newlines2[-1].strip():
    del newlines2[-1]
if newlines2 and newlines2[-1].startswith("***END*PASS***"):
    del newlines2[-1]

# ...write the table lines: one table row per solution, one text line per IF,
# all lines of a solution carrying the same row number.
# NOTE(review): weight (1.0) and reference antenna (8) are hard-coded, and the
# continuation-line template has 10 '' placeholders while the first line has 9
# fields before the real part -- confirm both against the TBOUT/TBIN layout.
for n, solution in enumerate(newoutput):
    rownum = n + max   # `max` is the first free row number, set when the table was opened
    timestamp = solution[0]
    ifs = solution[2]
    src = solution[3]
    for y in range(noifs):
        g_re, g_im, g_delay, g_rate = ifs[y]
        if y == 0:
            line = "%8i   %1.15E   0.000000E+00          " % (rownum, timestamp / 86400)
            line = line + src + "          " + antenna + "          1          " + freqid2 + "   0.000000E+00          0   0.000000E+00  "
        else:
            line = "%8i                      ''             ''         ''         ''         ''         ''         ''             ''         ''             ''  " % rownum
        line = line + "% 1.6E  % 1.6E  % 1.6E  % 1.6E   1.000000E+00          8\n" % (g_re, g_im, g_delay, g_rate)
        newlines2.append(line)

# ...add end mark.
newlines2.append("***END*PASS***\n")

# Write to a temporary file and rename it over the old table, so a failure
# while writing cannot leave the table deleted or half-written (the old code
# removed the file first).
tmpname = outname + ".tmp"
with open(tmpname, 'w') as outfile:
    outfile.writelines(newlines2)
os.rename(tmpname, outname)

