46da7c6bdd3b96d06934043910c59d1a3e93bf81
[matches/honours.git] / research / TCS / process.py
1 #!/usr/bin/python -u
2
#
# @file process.py
# @purpose Process Total Current Spectrum (TCS) data.
#               Computes S(E) = dI/dE from the measured sample current I(E).
# @author Sam Moore
# @date August 2012
#
10
11 import sys
12 import os
13 import re # Regular expressions - for removing comments
14 import odict #ordered dictionary
15 import copy
16
17 import Gnuplot, Gnuplot.funcutils
18 import string
19 import time
20 import math
21 import cmath
22 import random
23 import numpy
24
# Module-level gnuplot session; the plotting functions below send their
# commands and plot items through it.
gnuplot = Gnuplot.Gnuplot()
26
def Reset():
    """Replace the module-level gnuplot session with a fresh one.

    Without the ``global`` declaration the assignment only created a local
    variable and the shared session was never replaced.

    Note: functions in this module whose ``plot`` default was bound to a
    method of the previous session at definition time keep using that old
    session unless ``plot`` is passed explicitly.
    """
    global gnuplot
    gnuplot = Gnuplot.Gnuplot()
29
def FindDataFiles(directory=".", depth=1, result=None):
    """Collect the paths of ``*.dat`` files below ``directory``.

    Only names of the exact form ``<stem>.dat`` (a single dot) are accepted.

    Parameters:
        directory -- directory to scan.
        depth     -- how many directory levels to descend; 1 scans only
                     ``directory`` itself.
        result    -- list to append to (created when None).

    Returns the list of paths, each built as ``directory + "/" + name``.
    """
    if result is None:
        result = []

    for name in os.listdir(directory):
        path = directory + "/" + str(name)
        if os.path.isdir(path):
            if depth > 1:
                # The recursive call appends straight into `result`.
                # Extending `result` with its return value (the very same
                # list) would duplicate every entry found so far.
                FindDataFiles(path, depth - 1, result)
            continue
        parts = name.split(".")
        if len(parts) == 2 and parts[1] == "dat":
            result.append(path)

    return result
44
def BaseName(f):
    """Return the part of path ``f`` after its last "/" (all of ``f`` if none)."""
    return f.split("/")[-1]
48         
49
def DirectoryName(f, start=0,back=1):
    """Return a "/"-joined slice of the components of path ``f``.

    Parameters:
        f     -- path using "/" as separator.
        start -- index of the first component to keep.
        back  -- number of trailing components to drop (1 drops the file name).

    Uses ``str.join`` rather than the deprecated ``string.join`` helper,
    which no longer exists in Python 3.
    """
    parts = f.split("/")
    return "/".join(parts[start:(len(parts) - back)])
53
def GetData(filename, key=1):
    """Load a tab-separated data file into a list of rows of floats.

    Parameters:
        filename -- path of a data file, or a directory (its data files are
                    averaged into ``average.dat`` first, then that file is
                    loaded), or an already loaded list (returned unchanged).
                    Any other type yields the placeholder ``[[0,0,0,0]]``.
        key      -- column index used to group rows.

    Text after "#" on a line is ignored, as are empty lines.  Rows sharing
    the same value in column ``key`` are averaged column by column.

    Returns the rows sorted by column ``key``.
    """
    if not isinstance(filename, str):
        if isinstance(filename, list):
            return filename
        return [[0, 0, 0, 0]]

    if os.path.isdir(filename):
        # rstrip, not strip: stripping both ends would also drop the leading
        # "/" of an absolute path and silently make it relative.
        average_path = filename.rstrip("/") + "/average.dat"
        if os.path.exists(average_path):
            os.remove(average_path)
        AverageAllData(filename)
        return GetData(average_path)

    # Maps key value -> [column sums, number of rows merged].
    data = {}
    with open(filename, "r") as input_file:
        for line in input_file:
            line = re.sub("#.*", "", line).strip("\r\n ")
            if len(line) == 0:
                continue

            row = [float(e) for e in line.split("\t")]

            if row[key] in data:
                entry = data[row[key]]
                for i in range(len(row)):
                    entry[0][i] += row[i]
                entry[1] += 1
            else:
                data[row[key]] = [row, 1]

    d = [[float(v) / float(count) for v in total]
         for total, count in data.values()]
    d.sort(key=lambda e: e[key])
    return d
90
def DoNothing(data):
    """Identity function: hand ``data`` back untouched (no copy is made)."""
    return data
93
94
def GetDataSets(directory="."):
    """Load every data file directly inside ``directory``.

    A file qualifies when the text between its first and second "." is
    "dat".  Files that load to an empty data set are left out.

    Returns a list with one loaded data set per qualifying file.
    """
    data_sets = []
    for name in os.listdir(directory):
        path = directory + "/" + str(name)
        if os.path.isdir(path):
            continue
        pieces = name.split(".")
        if len(pieces) < 2 or pieces[1] != "dat":
            continue
        rows = GetData(path)
        if len(rows) > 0:
            data_sets.append(rows)
    return data_sets
104
105
106
def Derivative(data, a=1, b=2, sigma=None,step=1):
    """Numerically differentiate column ``b`` with respect to column ``a``.

    Parameters:
        data  -- list of rows (not modified).
        a, b  -- indices of the independent and dependent columns.
        sigma -- optional signed column index; half of that column's value
                 is added to (sigma >= 0) or subtracted from (sigma < 0)
                 each difference of column ``b``.
        step  -- stride between the rows that are used.

    Returns copies of every ``step``-th row with column ``b`` replaced by
    the slope.  Interior rows combine the backward and forward differences;
    the ends use the single available one.  The slope is 0.0 when no
    difference is available or when the change in column ``a`` is zero.
    """
    rows = []
    last_forward = len(data) - step

    for i in range(0, len(data), step):
        here = data[i]
        row = list(here)

        # (dE, dI) pairs: backward difference first, then forward.
        diffs = []
        if i >= step:
            prev = data[i - step]
            diffs.append((here[a] - prev[a], here[b] - prev[b]))
        if i < last_forward:
            nxt = data[i + step]
            diffs.append((nxt[a] - here[a], nxt[b] - here[b]))

        if sigma is not None and diffs:
            shift = 0.5 * here[int(abs(sigma))]
            if sigma < 0:
                shift = -shift
            diffs = [(dE, dI + shift) for dE, dI in diffs]

        slope = 0.0
        if diffs:
            deltaE = 0.0
            deltaI = 0.0
            for dE, dI in diffs:
                deltaE += dE
                deltaI += dI
            deltaI /= float(len(diffs))
            deltaE /= float(len(diffs))
            if deltaE != 0:
                slope = deltaI / deltaE

        row[b] = slope
        rows.append(row)

    return rows
162
def MaxNormalise(data, u=2):
    """Return a deep copy of ``data`` with column ``u`` divided by its maximum.

    The copy is returned unscaled when ``data`` is empty or the maximum of
    column ``u`` is zero.  ``data`` itself is never modified.
    """
    normalised = copy.deepcopy(data)
    if len(data) > 0:
        largest = max(row[u] for row in data)
        if largest != 0:
            for row in normalised:
                row[u] = row[u] / largest
    return normalised
176         
def Average(data_sets, u=1):
    """Average several data sets row-wise, grouping rows by column ``u``.

    Parameters:
        data_sets -- iterable of data sets, each a list of rows.
        u         -- index of the column whose value identifies a row.

    Returns a list of rows sorted by the grouping value; each column holds
    the mean over all rows that shared that value.

    The input rows are copied before accumulating.  Storing the caller's
    row and summing into it, as was done before, overwrote the first data
    set that contained each key.  A plain dict suffices for the
    accumulation because the result is sorted by key anyway.
    """
    # Maps grouping value -> [column sums, number of rows merged].
    totals = {}
    for data_set in data_sets:
        for row in data_set:
            entry = totals.get(row[u])
            if entry is None:
                totals[row[u]] = [list(row), 1]
            else:
                entry[1] += 1
                for i in range(len(row)):
                    entry[0][i] += row[i]

    result = []
    for value in sorted(totals):
        summed, count = totals[value]
        result.append([v / float(count) for v in summed])
    return result
194
def FullWidthAtHalfMax(data, u=1):
    """Estimate the full width at half maximum of the tallest peak.

    Column 0 is taken as the position axis (as in the fallback the original
    code used) and column ``u`` as the value.

    Starting from the row with the largest value, the nearest row on each
    side whose value drops below half of that maximum is located; the
    result is the distance between their positions.  When either side
    never drops below half maximum, the full span of column 0 is returned.

    Raises ValueError when ``data`` is empty.

    Fixes over the previous version: the half-maximum test compared whole
    rows against a number, the peak value was read from column 0 instead of
    ``u``, the right edge stored an index while the left stored a value,
    index 0 was never examined and the search range used a division that
    yields a float on Python 3.
    """
    peak_row = max(data, key=lambda e: e[u])
    peak = data.index(peak_row)
    half = 0.50 * peak_row[u]

    lhs = None
    for i in range(peak - 1, -1, -1):
        if data[i][u] < half:
            lhs = data[i][0]
            break

    rhs = None
    for i in range(peak + 1, len(data)):
        if data[i][u] < half:
            rhs = data[i][0]
            break

    if lhs is None or rhs is None:
        return abs(data[len(data) - 1][0] - data[0][0])
    return abs(rhs - lhs)
214
def SaveData(filename, data):
    """Write ``data`` to ``filename`` as tab-separated text, one row per line.

    Each value is converted with ``str``.  The file is opened in a ``with``
    block so it is flushed and closed even if writing fails; the previous
    unbuffered handle was never closed, and unbuffered text mode is not
    accepted by Python 3.
    """
    with open(filename, "w") as out:
        for row in data:
            out.write("\t".join(str(value) for value in row))
            out.write("\n")
223
def AverageAllData(directory, save=None, normalise=True):
    """Average every data file found in ``directory`` and save the result.

    Parameters:
        directory -- directory searched with FindDataFiles.
        save      -- output path; defaults to ``directory + "/average.dat"``.
        normalise -- when true, each data set is passed through
                     MaxNormalise before averaging.

    Returns the averaged data (also written to ``save``).

    The output file is skipped when it already exists inside the
    directory; otherwise the previous average would be folded back into
    the new one.
    """
    if save is None:
        save = directory + "/average.dat"
    save_path = os.path.abspath(save)

    data_sets = []
    for f in FindDataFiles(directory):
        if os.path.abspath(f) == save_path:
            continue
        d = GetData(f)
        if normalise:
            d = MaxNormalise(d)
        data_sets.append(d)

    a = Average(data_sets)
    SaveData(save, a)
    return a
236
def CalibrateData(original, ammeter_scale=1e-6):
    """Return a calibrated deep copy of ``original``.

    Column 1 is scaled by 16.8/4000 and columns 2 and 3 by
    ``ammeter_scale * 0.170 / 268``.
    NOTE(review): presumably DAC counts -> volts and ADC counts -> current;
    confirm the constants against the hardware.
    """
    current_gain = ammeter_scale * 0.170
    calibrated = copy.deepcopy(original)
    for row in calibrated:
        row[1] = 16.8 * float(row[1]) / 4000.0
        row[2] = current_gain * float(row[2]) / 268.0
        row[3] = current_gain * float(row[3]) / 268.0
    return calibrated
244
def ShowTCS(filename, raw=True,calibrate=True, normalise=False, show_error=False, plot=gnuplot.plot,with_="lp", step=1, output=None, title="", master_title="", smooth=0, show_peak=True, inflection=1):
    """Plot the total current spectrum S(E) = dI/dE of a data set.

    Parameters:
        filename     -- path handed to GetData, or already loaded data.
        raw          -- True when the data is I(E) and must be
                        differentiated; False when it already is S(E)
                        (calibration and normalisation are then skipped).
        calibrate    -- convert counts with CalibrateData.
        normalise    -- scale with MaxNormalise.
        show_error   -- also plot the derivative shifted by -/+ half of
                        column 3.
        plot         -- callable receiving the plot items; None to get the
                        items back instead of plotting.
        with_        -- gnuplot style for the curves.
        step         -- stride passed to Derivative.
        output       -- optional file name for a PNG.
        title        -- curve title; defaults to the file's base name.
        master_title -- plot title; a default is built when empty.
        smooth       -- window ``m`` for Smooth, or ``[repeats, m]``.
        show_peak    -- overlay peaks found by SmoothPeakFind.
        inflection   -- forwarded to SmoothPeakFind.

    Returns the (smoothed/calibrated/normalised) data, or the list of plot
    items when ``plot`` is None.  Empty data is returned immediately.
    """
    if raw == False:
        calibrate = False
        normalise = False

    if isinstance(filename, str):
        data = GetData(filename)
    else:
        data = filename
        filename = "tcs data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Test for the list form first: ordering a list against a number only
    # happened to work under Python 2's cross-type comparison rules.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA / V"]
    else:
        units = ["DAC counts", "ADC counts / DAC counts"]

    if not normalise:
        gnuplot("set ylabel \"dI(E)/dE ("+str(units[1])+")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"dI(E)/dE (normalised)\"")

    to_file = output is not None and isinstance(output, str)
    if to_file:
        gnuplot("set term png size 640,480")
        gnuplot("set output \""+str(output)+"\"")

    if master_title == "":
        master_title = "Total Current Spectrum S(E)"
        if isinstance(filename, str) and plot == gnuplot.plot:
            if filename != "tcs data":
                p = ReadParameters(filename)
                if "Sample" in p:
                    master_title += "\\nSample: "+p["Sample"]

    gnuplot("set title \""+str(master_title)+"\"")
    gnuplot("set xlabel \"U ("+str(units[0])+")\"")

    if raw:
        d = Derivative(data, 1, 2, step=step)
    else:
        d = data

    ymax = 0.01 + 1.2 * max(d, key=lambda e : e[2])[2]
    ymin = -0.01 + 1.2 * min(d, key=lambda e : e[2])[2]
    gnuplot("set yrange ["+str(ymin)+":"+str(ymax)+"]")

    plotList = []
    plotList.append(Gnuplot.Data(d, using="2:3", with_=with_,title=title))

    if show_error:
        error1 = Derivative(data, 1, 2, -3,step=step)
        error2 = Derivative(data, 1, 2, +3,step=step)
        # These used an undefined name `w`, raising NameError whenever
        # show_error was requested.
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_,title="-sigma/2"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="+sigma/2"))

    if show_peak:
        peak = SmoothPeakFind(d, ap=DoNothing, stop=1, inflection=inflection)
        plotList += PlotPeaks(peak,with_="l lt -1", plot=None)

    if plot is not None:
        plot(*plotList)
        time.sleep(0.2)

    if to_file:
        gnuplot("set term wxt")

    if plot is None:
        return plotList
    return data
333
def ShowData(filename,calibrate=True, normalise=False, show_error=False, plot=gnuplot.plot,with_="lp", step=1, output=None, title="", master_title="Sample Current I(E)", smooth=0):
    """Plot the sample current I(E) of a data set.

    Parameters:
        filename     -- path handed to GetData, or already loaded data.
        calibrate    -- convert counts with CalibrateData.
        normalise    -- scale with MaxNormalise.
        show_error   -- also plot column 2 shifted by -/+ half of column 3.
        plot         -- callable receiving the plot items; None to get the
                        items back instead of plotting.
        with_        -- gnuplot style for the curves.
        step         -- accepted for symmetry with ShowTCS; unused here.
        output       -- optional file name for a PNG.
        title        -- curve title; defaults to the file's base name.
        master_title -- plot title.
        smooth       -- window ``m`` for Smooth, or ``[repeats, m]``.

    Returns the (smoothed/calibrated/normalised) data, or the list of plot
    items when ``plot`` is None.  Empty data is returned immediately.
    """
    if isinstance(filename, str):
        data = GetData(filename)
    else:
        data = filename
        filename = "raw data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Smooth was called without its required window argument `m`, so any
    # request for smoothing raised TypeError.  Handled as in ShowTCS.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA"]
    else:
        units = ["DAC counts", "ADC counts"]

    if not normalise:
        gnuplot("set ylabel \"I(E) ("+str(units[1])+")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"I(E) (normalised)\"")

    to_file = output is not None and isinstance(output, str)
    if to_file:
        gnuplot("set term png size 640,480")
        gnuplot("set output \""+str(output)+"\"")

    gnuplot("set title \""+str(master_title)+"\"")
    gnuplot("set xlabel \"U ("+str(units[0])+")\"")

    plotList = []
    plotList.append(Gnuplot.Data(data, using="2:3", with_=with_,title=title))
    time.sleep(0.1)

    if show_error:
        error1 = copy.deepcopy(data)
        error2 = copy.deepcopy(data)
        for i in range(len(data)):
            error1[i][2] -= 0.50*float(data[i][3])
            error2[i][2] += 0.50*float(data[i][3])
        # These used an undefined name `w`, raising NameError whenever
        # show_error was requested.
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_,title="Error : Low bound"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="Error : Upper bound"))

    if plot is None:
        return plotList

    plot(*plotList)
    if to_file:
        gnuplot("set term wxt")
    return data
396
def ReadParameters(filename):
    """Read ``name = value`` pairs from a file into an ordered dictionary.

    Every line containing "=" contributes one entry.  Name and value are
    stripped of surrounding "#", spaces and line endings; a later
    occurrence of a name overwrites the earlier value.

    The line is split at the first "=" only, so a value that itself
    contains "=" is kept whole instead of being cut at the second one.
    The file is closed even when reading fails.

    Returns an ``odict.odict`` mapping names to value strings.
    """
    parameters = odict.odict()
    with open(filename, "r") as input_file:
        for line in input_file:
            item, separator, value = line.partition("=")
            if not separator:
                continue
            item = item.strip("# \r\n")
            value = value.strip("# \r\n")
            # Kept as in the original: new keys go through update().
            # NOTE(review): presumably odict also accepts plain item
            # assignment for new keys -- confirm before simplifying.
            if item in parameters:
                parameters[item] = value
            else:
                parameters.update({str(item) : value})
    return parameters
413
def PlotParameters(filename):
        """Read the parameters stored in ``filename``.

        NOTE(review): despite the name nothing is plotted and the parsed
        parameters are discarded -- this looks like an unfinished stub.
        """
        ReadParameters(filename)
416
def Smooth(data, m, k=2):
    """Return a deep copy of ``data`` with column ``k`` replaced by a moving average.

    For row ``i`` the average runs over rows ``i-m`` up to but excluding
    ``i+m``, clipped to the data.  An empty window (e.g. ``m == 0``)
    leaves the value unchanged.  ``data`` is not modified.
    """
    total_rows = len(data)
    smoothed = copy.deepcopy(data)
    for i, row in enumerate(smoothed):
        first = max(i - m, 0)
        stop = min(i + m, total_rows)
        window = [data[j][k] for j in range(first, stop)]
        if window:
            row[k] = sum(window, 0.0) / float(len(window))
        else:
            row[k] = data[i][k]
    return smoothed
432
def PeakFind(data, k=2,threshold=0.00, inflection=0):
    """Find local extrema of column ``k``.

    An interior row counts when both neighbours lie strictly on the same
    side of it and each differs from it by at least ``threshold`` times
    its magnitude.  Every hit is returned as a copy of the row with
    ``inflection`` appended.

    When ``inflection`` is positive, the extrema of ``Derivative(data)``
    are appended as well, tagged with ``inflection - 1``, recursively.
    """
    found = []
    for idx in range(1, len(data) - 1):
        here = data[idx][k]
        to_prev = data[idx - 1][k] - here
        to_next = data[idx + 1][k] - here
        tolerance = threshold * abs(here)
        if abs(to_prev) < tolerance or abs(to_next) < tolerance:
            continue
        if to_prev * to_next > 0:
            found.append(data[idx] + [inflection])

    if inflection > 0:
        found.extend(PeakFind(Derivative(data), k=k, threshold=threshold, inflection=inflection-1))

    return found
454
def SmoothPeakFind(data, a=1, k=2, ap=DoNothing, stop=10,smooth=5, inflection=0):
    """Track peaks of ``data`` across successive smoothing passes.

    For pass ``m`` = 0 .. stop-1, peaks are found with
    ``PeakFind(ap(current), k, inflection)`` and the data is then smoothed
    once more with ``Smooth(current, m=smooth, k=k)``.

    Each peak becomes an entry ``[m] + peak``.  In pass 0 every peak starts
    a new track.  Later, a peak joins the track whose latest entry is at
    most one pass old and closest in column ``a``, unless that distance
    exceeds 100 or no such track exists, in which case it starts a new
    track.

    Returns the list of tracks, each a list of entries.
    """
    tracks = []
    current = data

    m = 0
    while m < stop:
        for p in PeakFind(ap(current), k=k, inflection=inflection):
            entry = [m] + list(p)

            # (track index, distance) of the nearest recently updated track;
            # the strict "<" keeps the earliest track on ties.
            best = None
            if m != 0:
                for idx, track in enumerate(tracks):
                    tail = track[-1]
                    if m - tail[0] > 1:
                        continue
                    distance = abs(p[a] - tail[1 + a])
                    if best is None or distance < best[1]:
                        best = (idx, distance)

            if best is None or best[1] > 100:
                tracks.append([entry])
            else:
                tracks[best[0]].append(entry)

        m += 1
        current = Smooth(current, m=smooth, k=k)

    return tracks
500         
501                         
502
def PlotPeaks(peaks, calibrate=True, with_="lp", plot=gnuplot.replot):
    """Build (and optionally draw) one gnuplot item per peak track.

    Parameters:
        peaks     -- tracks as returned by SmoothPeakFind.  NOTE: each track
                     is modified in place -- a copy of its last entry with
                     the pass number increased by one is appended.
                     (Presumably so a single-entry track still forms a line
                     segment; confirm.)
        calibrate -- unused.
        with_     -- gnuplot style; for tracks whose last field (the
                     inflection tag) is below 1, any " lt ..." suffix is
                     replaced by " lt 9".
        plot      -- callable used to draw the items; None to skip drawing.

    Returns the list of plot items.

    The style is now chosen per track.  Previously the " lt 9" override was
    written back into ``with_`` and therefore also applied to every later
    track regardless of its tag.
    """
    plotList = []
    for p in peaks:
        tail = copy.deepcopy(p[len(p)-1])
        tail[0] += 1
        p.append(tail)

        style = with_
        if tail[len(tail)-1] < 1:
            style = with_.split(" lt")[0] + " lt 9"
        plotList.append(Gnuplot.Data(p, using="3:1", with_=style))

    if len(plotList) > 0 and plot is not None:
        plot(*plotList)
        time.sleep(0.2)

    return plotList
529                 
def main():
    """Command-line entry point: plot each file named on the command line.

    Switches are processed left to right and affect the files after them:
        --raw                 plot I(E) with ShowData
        --tcs                 plot S(E) with ShowTCS, without peak markers
        --output FILE         postscript output to FILE
        --wxt                 switch to the wxt terminal
        --normalise / --unnormalise
        --title TEXT          curve title
        --master_title TEXT   plot title
    Any other argument is a file to plot.

    Returns 1 when no arguments are given, 0 otherwise.
    """
    if len(sys.argv) < 2:
        sys.stderr.write(sys.argv[0] + " - Require arguments (filename)\n")
        return 1

    def option_value(index):
        # Value following the switch at `index`; exits if it is missing.
        if index + 1 >= len(sys.argv):
            sys.stderr.write("Need argument for "+sys.argv[index]+" switch\n")
            sys.exit(1)
        return sys.argv[index + 1]

    i = 1
    plotFunc = ShowTCS
    normalise = False
    title = ""
    master_title = ""
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg == "--raw":
            plotFunc = ShowData
        elif arg == "--tcs":
            # Must forward keyword arguments: the plot call below passes
            # plot=, normalise=, ..., which a one-argument lambda rejects.
            plotFunc = lambda e, **kwargs : ShowTCS(e, show_peak=False, **kwargs)
        elif arg == "--output":
            target = option_value(i)
            gnuplot("set term postscript colour")
            gnuplot("set output \""+target+"\"")
            i += 1
        elif arg == "--wxt":
            gnuplot("set term wxt")
        elif arg == "--normalise":
            normalise = True
        elif arg == "--unnormalise":
            normalise = False
        elif arg == "--title":
            title = option_value(i)
            i += 1
        elif arg == "--master_title":
            master_title = option_value(i)
            i += 1
        else:
            options = {"plot" : gnuplot.replot, "normalise" : normalise, "title" : title}
            # Only pass a master title the user actually supplied, so an
            # empty string does not blank ShowData's default title.
            if master_title != "":
                options["master_title"] = master_title
            plotFunc(arg, **options)

        i += 1

    # sys.stdout.write produces the same output as the Python 2 print
    # statement and also parses under Python 3.
    sys.stdout.write("Done. Press enter to exit, or type name of file to save as.\n")
    out = sys.stdin.readline().strip("\t\r\n #")
    if out != "":
        gnuplot("set term postscript colour")
        gnuplot("set output \""+out+"\"")
        gnuplot.replot()

    return 0
582
583
def ModelTCS(f, sigma, Emin, Emax, dE):
    """Tabulate a model spectrum for E from ``Emin`` up to (excluding) ``Emax``.

    For each energy E the value is
        (1 - sigma(0)) * f(-E) + integral over e of f(e - E) * dsigma/dE(E)
    with the integral taken by FuncIntegrate over [Emin, Emax] in steps of
    ``dE`` and the derivative by FuncDerivative.

    Returns rows of the form ``[0.0, E, S, 0.0]``.
    """
    spectrum = []
    energy = Emin
    while energy < Emax:
        def integrand(e, E=energy):
            return f(e - E) * FuncDerivative(sigma, E, dE)

        value = (1 - sigma(0))*f(-energy) + FuncIntegrate(integrand, Emin, Emax, dE)
        spectrum.append([0.00, energy, value, 0.00])
        energy += dE
    return spectrum
592
def IntegrateTCS(data, imin, imax=0, di=1):
    """Sum column 2 times the spacing of column 1 over a range of rows.

    Starting at row ``imin`` and advancing by ``di`` while the index is
    below ``imax`` (the last row index when ``imax`` is 0), each visited
    row contributes ``value * (next row's column 1 - its own column 1)``.
    """
    stop = len(data) - 1 if imax == 0 else imax
    area = 0.0
    idx = imin
    while idx < stop:
        row = data[idx]
        following = data[idx + 1]
        area += row[2] * (following[1] - row[1])
        idx += di
    return area
603
def FuncIntegrate(f, xmin, xmax, dx):
    """Rectangle-rule sum of ``f(x) * dx`` for x = xmin, xmin+dx, ... while x <= xmax."""
    def sample_points():
        x = xmin
        while x <= xmax:
            yield x
            x += dx

    area = 0.0
    for point in sample_points():
        area += f(point) * dx
    return area
611
def FuncDerivative(f, x, dx):
    """Central-difference estimate of f'(x) using the points x+dx and x-dx."""
    ahead = f(x + dx)
    behind = f(x - dx)
    return 0.50 * (ahead - behind) / dx
614
def FitTCS(data, min_mse=1e-4, max_fail=100, max_adjust=4,divide=10, plot=gnuplot.plot,smooth=0):
    """Fit a sum of gaussians to a total current spectrum by random search.

    Parameters:
        data       -- file name (loaded, calibrated, normalised and
                      differentiated here) or an already prepared spectrum.
        min_mse    -- stop once the mean square error is at or below this.
        max_fail   -- stop after this many consecutive rejected trials.
        max_adjust -- the trial step size is halved down to 1/2**max_adjust.
        divide     -- only every ``divide``-th data point is compared.
        plot       -- callable used to draw the result; None to return the
                      plot items instead.
        smooth     -- window ``m`` for Smooth, or ``[repeats, m]``.

    Each fit is ``[amplitude, centre, width]``, seeded from the extrema
    found by SmoothPeakFind.  One randomly chosen parameter is perturbed
    per trial and the change is kept only if the error decreases.

    Returns ``[fits, model]``, or ``[fits, model, plotItems]`` when
    ``plot`` is None.  ``fits`` is sorted by peak height, tallest first.

    Fixes over the previous version:
      * neighbouring tracks were indexed one past their end (IndexError);
      * the positivity guard for the width tested the peak index instead of
        the parameter index;
      * the returned model is rebuilt from the final fits (it could be the
        last rejected trial, or unbound if the loop never ran);
      * an empty list of fits no longer crashes random.randint;
      * the ``plot`` argument is honoured when drawing.
    """
    if isinstance(data, str):
        d = GetData(data)
        d = CalibrateData(d)
        d = MaxNormalise(d)
        d = Derivative(d)
    else:
        d = data

    if smooth != 0:
        if isinstance(smooth, list):
            for _ in range(smooth[0]):
                d = Smooth(d, m=smooth[1])
        else:
            d = Smooth(d, m=smooth)

    plotItems = ShowTCS(d, raw=False,smooth=smooth,plot=None)
    plotItems.append(None)

    peaks = SmoothPeakFind(d, smooth=5, stop=1, inflection=0)
    peaks.sort(key = lambda e : e[len(e)-1][1])

    def last_entry(track):
        return track[len(track)-1]

    def is_extremum(entry):
        # The trailing field is the inflection tag; 0 marks a plain extremum.
        return entry[len(entry)-1] == 0

    fits = []
    for i in range(0, len(peaks)):
        l = last_entry(peaks[i])
        if is_extremum(l):
            fits.append([l[3], l[2], 1.0])
        else:
            if i-2 >= 0:
                l = last_entry(peaks[i-2])
                if is_extremum(l):
                    fits.append([l[3], l[2], 1.0])
            if i+2 <= len(peaks)-1:
                l = last_entry(peaks[i+2])
                if is_extremum(l):
                    fits.append([l[3], l[2], 1.0])

    # Initial width: half the distance to the nearest neighbouring fit.
    for i in range(len(fits)):
        left = 2.0
        right = 2.0
        if i > 0:
            left = fits[i-1][1] - fits[i][1]
        if i < len(fits)-1:
            right = fits[i+1][1] - fits[i][1]

        fits[i][2] = min([abs(0.5*left), abs(0.5*right)])

    def tcs(E):
        total = 0.0
        for f in fits:
            total += f[0] * gaussian(E - f[1], f[2])
        return total

    def build_model():
        # Tabulate the current fits on the grid the data is compared on.
        if len(d) == 0:
            return []
        return table(lambda e : [0.00, e, tcs(e), 0.00], 0, 16.8, divide*d[len(d)-1][1]/(len(d)))

    mse = 1
    old_mse = 1
    failcount = 0
    adjust = 1.0

    while len(fits) > 0 and failcount < max_fail and mse > min_mse:
        i = random.randint(0, len(fits)-1)
        j = random.randint(0, len(fits[i])-1)

        old = fits[i][j]
        old_mse = mse

        fits[i][j] += adjust * (random.random() - 0.50)
        if j == 2:
            # Keep the gaussian width strictly positive.
            while fits[i][j] <= 0.0005:
                fits[i][j] = adjust * (random.random() - 0.50)

        model = build_model()
        mse = MeanSquareError(model, d[0::divide])
        if mse >= old_mse:
            # Reject the trial; shrink the step once failures pile up.
            fits[i][j] = old
            failcount += 1
            if failcount > max_fail / 2:
                if adjust > 1.0/(2.0**max_adjust):
                    adjust /= 2
            mse = old_mse
        else:
            failcount = 0

    # Rebuild from the accepted parameters so model and fits agree.
    model = build_model()

    plotItems[len(plotItems)-1] = Gnuplot.Data(model, using="2:3", with_="l lt 3", title="model")
    time.sleep(0.1)

    fits.sort(key = lambda e : e[0] * gaussian(0, e[2]), reverse=True)

    if plot is not None:
        gnuplot("set title \"MSE = "+str(mse)+"\\nfailcount = "+str(failcount)+"\\nadjust = "+str(adjust)+"\"")
        plot(*plotItems)
        return [fits, model]

    return [fits, model,plotItems]
737         else:
738                 return [fits, model,plotItems]
739
740
741         #return model
742
def SaveFit(filename, fit):
    """Write a list of fits to ``filename``.

    The file starts with the header line "# TCS Fit", followed by one line
    per fit holding its first three values separated by tabs.

    The file is opened in a ``with`` block so it is closed even when
    writing fails; unbuffered text mode, as used before, is not accepted by
    Python 3.
    """
    with open(filename, "w") as out:
        out.write("# TCS Fit\n")
        for f in fit:
            out.write(str(f[0]) + "\t" + str(f[1]) + "\t" + str(f[2]) + "\n")
751
752
def LoadFit(filename):
    """Load fits previously written by SaveFit.

    The first line must be the header "# TCS Fit" (any line ending).
    Reading stops at the first line that does not consist of exactly three
    tab-separated fields.

    Returns the fits as lists of three floats, sorted by
    ``f[0] * gaussian(0, f[2])``, largest first.

    Exits the program with status 1 when the header is missing.  The
    previous exit status of 0 reported success to the calling shell, and
    the file was left open on that path.
    """
    fit = []
    with open(filename, "r") as infile:
        if infile.readline().rstrip("\r\n") != "# TCS Fit":
            sys.stderr.write("Error loading fit from file " + str(filename) + "\n")
            sys.exit(1)

        for line in infile:
            f = line.strip("# \r\n\t").split("\t")
            if len(f) != 3:
                break
            fit.append([float(value) for value in f])

    fit.sort(key = lambda e : e[0] * gaussian(0, e[2]), reverse=True)
    return fit
776
def MeanSquareError(model, real, k = 2):
    """Mean of the squared differences between two data sets in column ``k``.

    Row ``i`` of ``model`` is compared with row ``i`` of ``real`` for every
    row of ``real``; ``model`` must be at least as long.

    The sum is divided by the number of compared rows.  It was previously
    divided by ``len(model)``, which understated the error whenever the
    model had more rows than the data.  Returns 0.0 when ``real`` is empty.
    """
    count = len(real)
    if count == 0:
        return 0.0

    total = 0.0
    for i in range(count):
        total += (model[i][k] - real[i][k])**2
    return total / count
785
def delta(x):
    """Discrete delta: 1.0 when ``x`` equals zero, otherwise 0.0."""
    return 1.0 if x == 0 else 0.0
791
def table(f, xmin, xmax, dx):
    """Return ``[f(x) for x = xmin, xmin+dx, ...]`` for as long as x <= xmax."""
    def sample_points():
        x = xmin
        while x <= xmax:
            yield x
            x += dx

    return [f(point) for point in sample_points()]
799
def gaussian(x, sigma):
    """Normalised gaussian of width ``sigma`` evaluated at ``x``.

    Returns 0.0 when ``sigma`` is zero instead of dividing by zero.
    """
    if sigma == 0.0:
        return 0.0
    exponent = - (x**2.0)/(2.0 * sigma**2.0)
    normalisation = sigma * (2.0 * math.pi)**0.50
    return math.exp(exponent) / normalisation
804
def step(x, sigma, T):
    """Smooth step ``1 / (exp((x - sigma)/T) + 1)``; 1.0 when ``T`` is zero."""
    if T != 0:
        return 1.0 / (math.exp((x - sigma)/T) + 1.0)
    return 1.0
809
# Run main() only when executed as a script; its return value becomes the
# process exit status.
if __name__ == "__main__":
        sys.exit(main())

UCC git Repository :: git.ucc.asn.au