88a536fba8bbf7242db5ba6b7138376fb55d6b32
[matches/honours.git] / research / TCS / process.py
1 #!/usr/bin/python -u
2
3 #
4 # @file process.py
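# @purpose Process TCS (Total Current Spectrum) data
#               Computes and plots S(E) = dI/dE, the derivative of the sample current I(E) with respect to energy E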
7 # @author Sam Moore
8 # @date August 2012
9 #
10
11 import sys
12 import os
13 import re # Regular expressions - for removing comments
14 import odict #ordered dictionary
15 import copy
16
17 import Gnuplot, Gnuplot.funcutils
18 import string
19 import time
20 import math
21 import cmath
22 import random
23 import numpy
24
# Shared Gnuplot session used by every plotting helper in this module.
# NOTE: several functions bind ``gnuplot.plot`` / ``gnuplot.replot`` as
# default arguments when they are defined, so those defaults keep pointing
# at this first session even after Reset() replaces it.
gnuplot = Gnuplot.Gnuplot()

def Reset():
    """Replace the module-wide Gnuplot session with a fresh one.

    Without the ``global`` declaration the new session was bound only to a
    local name and immediately discarded, so Reset() had no effect.
    """
    global gnuplot
    gnuplot = Gnuplot.Gnuplot()
29
def FindDataFiles(directory=".", depth=1, result=None):
    """Collect the paths of ``.dat`` files under *directory*.

    Parameters:
        directory -- directory to scan.
        depth     -- number of directory levels to search: 1 scans only
                     *directory*, 2 also scans its immediate subdirectories,
                     and so on.
        result    -- optional list that found paths are appended to (it is
                     modified in place and also returned).

    Only names with exactly one dot and a ``dat`` extension match
    (``run.dat`` matches, ``run.old.dat`` does not).  Each path is built as
    directory + "/" + name.

    Returns the list of matching paths.
    """
    if result is None:
        result = []

    for name in os.listdir(directory):
        path = directory + "/" + str(name)
        if os.path.isdir(path):
            if depth > 1:
                # The recursive call appends straight into ``result``.  The
                # old ``result += FindDataFiles(..., result)`` then appended
                # the list to itself, duplicating every entry found so far.
                FindDataFiles(path, depth - 1, result)
            continue
        parts = name.split(".")
        if len(parts) == 2 and parts[1] == "dat":
            result.append(path)

    return result
44
def BaseName(f):
    """Return the last "/"-separated component of *f* ("" if *f* ends in "/")."""
    return f.rpartition("/")[2]
48         
49
def DirectoryName(f, start=0, back=1):
    """Return a slice of the "/"-separated components of *f*, re-joined with "/".

    Components from index *start* up to, but excluding, the last *back* are
    kept.  With the defaults this is the parent directory of *f*, e.g.
    DirectoryName("a/b/c.dat") == "a/b".
    """
    parts = f.split("/")
    # str.join replaces string.join, which no longer exists in Python 3.
    return "/".join(parts[start:len(parts) - back])
53
def GetData(filename, key=1):
    """Read a tab-separated numeric data file, averaging rows that share a key.

    Text from ``#`` to the end of a line is a comment, and blank lines are
    skipped.  Every remaining line is split on tabs and converted to floats.
    Rows with the same value in column *key* are averaged column by column.

    Returns a list of rows (lists of floats) sorted by column *key*.
    """
    sums = {}  # key value -> [column-wise totals, number of rows]
    # ``with`` guarantees the file is closed (it was previously leaked).
    with open(filename, "r") as input_file:
        for line in input_file:
            line = re.sub("#.*", "", line).strip("\r\n ")
            if len(line) == 0:
                continue

            # A list comprehension, not map(), so indexing works on Python 3.
            row = [float(e) for e in line.split("\t")]
            k = row[key]
            if k in sums:
                totals = sums[k][0]
                for i in range(len(row)):
                    totals[i] += row[i]
                sums[k][1] += 1
            else:
                sums[k] = [row, 1]

    d = [[float(v) / float(count) for v in totals] for totals, count in sums.values()]
    d.sort(key=lambda e: e[key])
    return d
76
def DoNothing(data):
    """Identity transform: hand *data* back unchanged (the very same object)."""
    unchanged = data
    return unchanged
79
def AverageAllDataSets(directory=".", function=DoNothing):
    """Average the ``.dat`` files inside each immediate subdirectory of *directory*.

    For every subdirectory, each file whose second "."-separated name part
    is ``dat`` is loaded with GetData and passed through *function*.  The
    resulting data sets are combined with Average.

    Returns a dict mapping subdirectory name -> averaged data set.
    """
    dirs = {}
    for f in os.listdir(directory):
        subdir = directory + "/" + str(f)
        if not os.path.isdir(subdir):
            continue
        data_set = []
        for datafile in os.listdir(subdir):
            parts = datafile.split(".")
            # Skip names without an extension (previously IndexError), and
            # load each file by its full path; GetData used to be handed the
            # bare subdirectory name instead.
            if len(parts) > 1 and parts[1] == "dat":
                # Apply the caller's transform, which was previously ignored.
                data_set.append(function(GetData(subdir + "/" + datafile)))
        dirs[f] = Average(data_set)
    return dirs
92
def GetDataSets(directory="."):
    """Load every non-empty ``.dat`` file sitting directly in *directory*.

    A directory entry qualifies when it is not itself a directory and the
    second "."-separated part of its name is ``dat``.  Each qualifying file
    is read with GetData; files that yield no rows are left out.

    Returns a list of data sets, one per loaded file.
    """
    loaded = []
    for name in os.listdir(directory):
        path = directory + "/" + str(name)
        if os.path.isdir(path):
            continue
        pieces = name.split(".")
        if len(pieces) <= 1 or pieces[1] != "dat":
            continue
        rows = GetData(path)
        if rows:
            loaded.append(rows)
    return loaded
102
103
104
def Derivative(data, a=1, b=2, sigma=None, step=1):
    """Numerically differentiate column *b* with respect to column *a*.

    Every *step*-th row of *data* is copied, and its column *b* is replaced
    by mean(dI) / mean(dE).  The means are taken over whichever of the
    backward difference (to row i - step) and the forward difference (to
    row i + step) exist.  Rows with no neighbour, or whose mean dE is zero,
    get 0.0.

    If *sigma* is not None, 0.5 * row[int(abs(sigma))] is added to each dI
    (subtracted when sigma < 0); ShowTCS uses this to draw error curves.

    Returns a new list of rows; *data* is not modified.
    """
    result = []
    total_rows = len(data)
    for i in range(0, total_rows, step):
        here = data[i]
        row = list(here)

        # [dE, dI] for each neighbour that exists: backward first, then forward.
        diffs = []
        if i >= step:
            before = data[i - step]
            diffs.append([here[a] - before[a], here[b] - before[b]])
        if i < total_rows - step:
            after = data[i + step]
            diffs.append([after[a] - here[a], after[b] - here[b]])

        if sigma is not None:
            for pair in diffs:
                shift = 0.5 * here[int(abs(sigma))]
                if sigma < 0:
                    pair[1] -= shift
                else:
                    pair[1] += shift

        if not diffs:
            row[b] = 0.0
        else:
            mean_e = 0.0
            mean_i = 0.0
            for de, di in diffs:
                mean_e += de
                mean_i += di
            mean_i /= float(len(diffs))
            mean_e /= float(len(diffs))
            row[b] = (mean_i / mean_e) if mean_e != 0 else 0.0
        result.append(row)
    return result
160
def MaxNormalise(data, u=2):
    """Return a deep copy of *data* with column *u* divided by its maximum.

    *data* itself is not modified, and an empty data set yields an empty
    copy.  If the maximum of column *u* is zero, the copy is returned
    unscaled rather than raising ZeroDivisionError.
    """
    result = copy.deepcopy(data)
    if len(data) <= 0:
        return result
    maxval = max(row[u] for row in data)
    if maxval == 0:
        return result
    for row in result:
        # float() avoids Python 2 integer truncation for integer columns.
        row[u] = row[u] / float(maxval)
    return result
171         
def Average(data_sets, u=1):
    """Average several data sets point by point, matching rows on column *u*.

    Rows from all data sets that hold the same value in column *u* are
    averaged column by column.  Returns a new list of rows sorted by that
    value.

    The input data sets are left untouched.  Previously the first row seen
    for each key was accumulated into and divided in place, corrupting the
    caller's data.
    """
    # A plain dict is enough because the result is sorted by key at the end.
    sums = {}  # key value -> [column-wise totals, number of rows]
    for data_set in data_sets:
        for row in data_set:
            k = row[u]
            if k in sums:
                entry = sums[k]
                entry[1] += 1
                for i in range(len(row)):
                    entry[0][i] += row[i]
            else:
                sums[k] = [list(row), 1]  # copy, never alias the caller's row

    averaged = []
    for k in sorted(sums):
        totals, count = sums[k]
        averaged.append([v / float(count) for v in totals])
    return averaged
189
def FullWidthAtHalfMax(data, u=1, x=0):
    """Full width at half maximum of the highest peak in column *u*.

    Starting from the row where column *u* is largest, the search walks
    outwards in both directions to the first row whose column-*u* value is
    below half of that maximum.  It returns the distance between those two
    rows, measured in column *x*; x defaults to 0, the column the original
    fallback already used.  If the curve never drops below half maximum on
    one side, the full span abs(data[-1][x] - data[0][x]) is returned.

    Changes from the previous version:
      * the half-maximum level is taken from column *u* (it used column 0);
      * values are compared, not whole rows against a number (always False
        on Python 2, TypeError on Python 3);
      * both edges are in the same units (one edge used to be a row index);
      * the search is no longer limited to len(data)/2 steps;
      * row 0 can be an edge.

    Raises ValueError for empty *data*.
    """
    if len(data) == 0:
        raise ValueError("FullWidthAtHalfMax: empty data")
    peak = max(range(len(data)), key=lambda i: data[i][u])
    half = 0.50 * data[peak][u]

    lhs = None
    for i in range(peak - 1, -1, -1):
        if data[i][u] < half:
            lhs = data[i][x]
            break

    rhs = None
    for i in range(peak + 1, len(data)):
        if data[i][u] < half:
            rhs = data[i][x]
            break

    if lhs is None or rhs is None:
        return abs(data[len(data) - 1][x] - data[0][x])
    return abs(rhs - lhs)
209
def SaveData(filename, data):
    """Write *data* to *filename* as tab-separated text, one row per line.

    Each value is written with str().  The file is now closed when done.
    Previously it was left open, and the unbuffered text-mode call
    ``open(filename, "w", 0)`` is rejected by Python 3.
    """
    with open(filename, "w") as out:
        for row in data:
            out.write("\t".join(str(v) for v in row))
            out.write("\n")
218
def AverageAllData(directory, save=None):
    """Average every ``.dat`` file directly inside *directory* and save the result.

    Files are located with FindDataFiles (default depth), loaded with
    GetData and combined with Average.  The average is written with SaveData
    to *save*, or to "<directory>/average.dat" when *save* is None, and is
    also returned.
    """
    target = directory + "/average.dat" if save is None else save
    averaged = Average([GetData(path) for path in FindDataFiles(directory)])
    SaveData(target, averaged)
    return averaged
228
def CalibrateData(original, ammeter_scale=1e-6):
    """Return a calibrated deep copy of raw data rows.

    Column 1 is scaled by 16.8 / 4000.  Columns 2 and 3 are scaled by
    ammeter_scale * 0.170 / 268.  Every other column is copied unchanged,
    and *original* is not modified.
    """
    calibrated = copy.deepcopy(original)
    current_gain = ammeter_scale * 0.170
    for row in calibrated:
        row[1] = 16.8 * float(row[1]) / 4000.0
        for col in (2, 3):
            row[col] = current_gain * float(row[col]) / 268.0
    return calibrated
236
def ShowTCS(filename, raw=True, calibrate=True, normalise=False, show_error=False,
            plot=gnuplot.plot, with_="lp", step=1, output=None, title="",
            master_title="", smooth=0, show_peak=True, inflection=1):
    """Plot a Total Current Spectrum S(E) = dI/dE with Gnuplot.

    Parameters:
        filename     -- path of a data file (read with GetData) or an
                        already-loaded list of rows.
        raw          -- when True the data is differentiated with Derivative
                        before plotting; when False it is plotted as is, and
                        calibration and normalisation are skipped.
        calibrate    -- convert to physical units with CalibrateData.
        normalise    -- divide column 2 by its maximum with MaxNormalise.
        show_error   -- also plot the -sigma/2 and +sigma/2 curves, built by
                        Derivative from column 3.
        plot         -- Gnuplot plotting function, or None to return the
                        plot items instead of drawing them.
        with_        -- Gnuplot "with" style for the curves.
        step         -- row stride passed to Derivative.
        output       -- PNG file name to render to instead of the screen.
        title        -- legend title (defaults to the file's base name).
        master_title -- plot title.  Defaults to "Total Current Spectrum
                        S(E)", plus the file's "Sample" parameter when
                        plotting a file with gnuplot.plot.
        smooth       -- Smooth window, or [passes, window].
        show_peak    -- overlay peaks found by SmoothPeakFind.
        inflection   -- passed to SmoothPeakFind.

    Returns the list of Gnuplot items when *plot* is None; otherwise the
    smoothed/calibrated/normalised data.
    """
    if not raw:
        calibrate = False
        normalise = False

    if type(filename) == type(""):
        data = GetData(filename)
    else:
        data = filename
        filename = "tcs data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Check for the [passes, window] form before comparing with 0: comparing
    # a list with an int raises TypeError on Python 3.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA / V"]
    else:
        units = ["DAC counts", "ADC counts / DAC counts"]

    if not normalise:
        gnuplot("set ylabel \"dI(E)/dE (" + str(units[1]) + ")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"dI(E)/dE (normalised)\"")

    if output is not None and type(output) == type(""):
        gnuplot("set term png size 640,480")
        gnuplot("set output \"" + str(output) + "\"")

    if master_title == "":
        master_title = "Total Current Spectrum S(E)"
        if type(filename) == type("") and plot == gnuplot.plot:
            if filename != "tcs data":
                p = ReadParameters(filename)
                if "Sample" in p:
                    master_title += "\\nSample: " + p["Sample"]

    gnuplot("set title \"" + str(master_title) + "\"")
    gnuplot("set xlabel \"U (" + str(units[0]) + ")\"")

    if raw:
        d = Derivative(data, 1, 2, step=step)
    else:
        d = data

    # Pad the y range by 20% plus a small offset.
    ymax = 0.01 + 1.2 * max(d, key=lambda e: e[2])[2]
    ymin = -0.01 + 1.2 * min(d, key=lambda e: e[2])[2]
    gnuplot("set yrange [" + str(ymin) + ":" + str(ymax) + "]")

    plotList = [Gnuplot.Data(d, using="2:3", with_=with_, title=title)]

    if show_error:
        error1 = Derivative(data, 1, 2, -3, step=step)
        error2 = Derivative(data, 1, 2, +3, step=step)
        # These used ``with_=w``; ``w`` is undefined, so show_error=True
        # always raised NameError.
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_, title="-sigma/2"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="+sigma/2"))

    if show_peak:
        peak = SmoothPeakFind(d, ap=DoNothing, stop=1, inflection=inflection)
        plotList += PlotPeaks(peak, with_="l lt -1", plot=None)

    if plot is not None:
        plot(*plotList)
        time.sleep(0.2)

    if output is not None and type(output) == type(""):
        gnuplot("set term wxt")

    if plot is None:
        return plotList
    return data
325
def ShowData(filename, calibrate=True, normalise=False, show_error=False,
             plot=gnuplot.plot, with_="lp", step=1, output=None, title="",
             master_title="Sample Current I(E)", smooth=0):
    """Plot the sample current I(E) (column 2 against column 1) with Gnuplot.

    Parameters:
        filename     -- path of a data file (read with GetData) or an
                        already-loaded list of rows.
        calibrate    -- convert to physical units with CalibrateData.
        normalise    -- divide column 2 by its maximum with MaxNormalise.
        show_error   -- also plot column 2 minus/plus half of column 3.
        plot         -- Gnuplot plotting function, or None to return the
                        plot items instead of drawing them.
        with_        -- Gnuplot "with" style for the curves.
        step         -- unused here.
        output       -- PNG file name to render to instead of the screen.
        title        -- legend title (defaults to the file's base name).
        master_title -- plot title.
        smooth       -- Smooth window, or [passes, window].

    Returns the processed data when plotting, otherwise the list of
    Gnuplot items.
    """
    if type(filename) == type(""):
        data = GetData(filename)
    else:
        data = filename
        filename = "raw data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Smooth() has no default window, so the old ``Smooth(data)`` call always
    # raised TypeError.  Pass the requested window, accepting the same
    # [passes, window] form as ShowTCS.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA"]
    else:
        units = ["DAC counts", "ADC counts"]

    if not normalise:
        gnuplot("set ylabel \"I(E) (" + str(units[1]) + ")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"I(E) (normalised)\"")

    if output is not None and type(output) == type(""):
        gnuplot("set term png size 640,480")
        gnuplot("set output \"" + str(output) + "\"")

    gnuplot("set title \"" + str(master_title) + "\"")
    gnuplot("set xlabel \"U (" + str(units[0]) + ")\"")

    plotList = [Gnuplot.Data(data, using="2:3", with_=with_, title=title)]
    time.sleep(0.1)
    if show_error:
        error1 = copy.deepcopy(data)
        error2 = copy.deepcopy(data)
        for i in range(len(data)):
            error1[i][2] -= 0.50 * float(data[i][3])
            error2[i][2] += 0.50 * float(data[i][3])
        # These used ``with_=w`` (undefined): NameError whenever show_error was set.
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_, title="Error : Low bound"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="Error : Upper bound"))

    if plot is not None:
        plot(*plotList)
        if output is not None and type(output) == type(""):
            gnuplot("set term wxt")
        return data
    else:
        return plotList
388
def ReadParameters(filename):
    """Read ``name = value`` pairs from the lines of *filename*.

    Every line containing "=" contributes one entry.  The name is the text
    before the first "="; the value is the text between the first and the
    second "=" (or the end of the line).  Both are stripped of "#", spaces
    and line endings.  A later occurrence of a name overwrites the earlier
    value.

    Returns an odict (ordered dictionary) of name -> value strings.
    """
    parameters = odict.odict()
    # ``with`` closes the file even if processing a line raises.
    with open(filename, "r") as input_file:
        for line in input_file:
            k = line.split("=")
            if len(k) >= 2:
                item = k[0].strip("# \r\n")
                value = k[1].strip("# \r\n")
                if item in parameters:
                    parameters[item] = value
                else:
                    parameters.update({str(item): value})
    return parameters
405
def PlotParameters(filename):
    """Read the parameters recorded in *filename* and return them.

    NOTE(review): despite its name this does not plot anything; it only
    wraps ReadParameters.  It used to discard what it read and always return
    None.  Returning the parameters lets callers actually use the result.
    """
    return ReadParameters(filename)
408
def Smooth(data, m, k=2):
    """Return a copy of *data* with column *k* smoothed by a centred moving average.

    Each value is replaced by the mean of column *k* over rows i - m through
    i + m inclusive, clipped at the ends of the data.  The window used to
    stop at i + m - 1, which made it lopsided and shifted features by half a
    sample.  With m == 0 each value is its own average; with m < 0 values
    are copied unchanged.  *data* itself is not modified.
    """
    smooth = copy.deepcopy(data)
    n = len(data)
    for i in range(n):
        lo = max(i - m, 0)
        hi = min(i + m + 1, n)
        if hi > lo:
            total = 0.0
            for j in range(lo, hi):
                total += data[j][k]
            smooth[i][k] = total / float(hi - lo)
        else:
            smooth[i][k] = data[i][k]
    return smooth
424
def PeakFind(data, k=2, threshold=0.00, inflection=0):
    """Find local extrema of column *k*, optionally also of its derivatives.

    An interior row i is reported when both neighbours lie on the same side
    of it, i.e. the differences to rows i-1 and i+1 have the same sign.  A
    neighbour difference smaller than threshold * |data[i][k]| disqualifies
    the row.  Each reported row is returned as a new list with the current
    *inflection* level appended.

    When inflection > 0 the search is repeated on the derivative of column
    *k* with inflection - 1, so derivative extrema (inflection points of
    the original curve) are included and tagged with lower levels.
    """
    results = []
    for i in range(1, len(data) - 1):
        left = data[i - 1][k] - data[i][k]
        right = data[i + 1][k] - data[i][k]
        if abs(left) < threshold * abs(data[i][k]):
            continue
        if abs(right) < threshold * abs(data[i][k]):
            continue
        if left * right > 0:
            results.append(data[i] + [inflection])

    if inflection > 0:
        # Differentiate the column being searched; the default column 2 was
        # used regardless of k.
        results += PeakFind(Derivative(data, b=k), k=k, threshold=threshold,
                            inflection=inflection - 1)

    return results
446
def SmoothPeakFind(data, a=1, k=2, ap=DoNothing, stop=10, smooth=5, inflection=0, max_distance=100):
    """Track peaks of column *k* through successive smoothing passes.

    Pass m (m = 0, 1, ... while m < stop) runs PeakFind on ap(s), where s is
    *data* smoothed m times with Smooth(..., m=smooth, k=k).  Each peak is
    recorded as [m] + peak_row.

    On pass 0 every peak starts a new track.  On later passes a peak joins
    the track that is closest to it in column *a*, provided that track's
    latest entry comes from this pass or the previous one and the distance
    is at most *max_distance* (previously hard-coded to 100).  Otherwise the
    peak starts a new track.

    Returns the list of tracks, each a list of [m] + peak_row entries.
    """
    s = data
    peakList = []
    m = 0
    while m < stop:
        for p in PeakFind(ap(s), k=k, inflection=inflection):
            add = [m] + list(p)
            if m == 0:
                peakList.append([add])
                continue
            score = []  # [track index, distance in column a]
            for i, track in enumerate(peakList):
                latest = track[-1]
                if m - latest[0] > 1:
                    continue
                score.append([i, abs(p[a] - latest[1 + a])])
            score.sort(key=lambda e: e[1])
            if len(score) == 0 or score[0][1] > max_distance:
                peakList.append([add])
            else:
                peakList[score[0][0]].append(add)
        m += 1
        s = Smooth(s, m=smooth, k=k)
    return peakList
492         
493                         
494
def PlotPeaks(peaks, calibrate=True, with_="lp", plot=gnuplot.replot):
    """Build, and optionally draw, Gnuplot curves for peak tracks.

    *peaks* is a list of tracks as produced by SmoothPeakFind: lists of
    [pass] + row entries whose last element is the PeakFind inflection
    level.  Each track is plotted with columns "3:1" and extended by one
    extra point: a copy of its last entry with the pass number increased
    by one.  Tracks whose last entry has inflection level < 1 are drawn with
    line type 9.  Empty tracks are skipped.  *calibrate* is unused.

    When *plot* is not None and any curves were built, they are passed to
    it.  Returns the list of Gnuplot.Data items.

    The caller's tracks are no longer modified (the extra point used to be
    appended to them on every call).  The line-type override now applies
    only to the track that triggered it, instead of to every later track.
    """
    plotList = []
    for track in peaks:
        if len(track) == 0:
            continue
        last = copy.deepcopy(track[-1])
        last[0] += 1
        points = list(track) + [last]
        style = with_
        if last[-1] < 1:
            style = with_.split(" lt")[0] + " lt 9"
        plotList.append(Gnuplot.Data(points, using="3:1", with_=style))

    if len(plotList) > 0 and plot is not None:
        plot(*plotList)
        time.sleep(0.2)

    return plotList
521                 
def main():
    """Command-line entry point: plot the data files named on the command line.

    Arguments are processed left to right, and switches affect the files
    named after them:
        --raw, --tcs         plot with ShowData / ShowTCS (default ShowTCS)
        --output FILE        send gnuplot output to colour PostScript FILE
        --wxt                send gnuplot output to the wxt terminal
        --normalise, --unnormalise
        --title T, --master_title T
    Any other argument is plotted as a file using gnuplot.replot.  A switch
    that needs a value but has none exits with status 1.

    Afterwards one line is read from stdin.  If it is non-empty, it names a
    PostScript file that the current plot is saved to.  Returns 1 when no
    arguments are given, otherwise None.
    """
    if len(sys.argv) < 2:
        sys.stderr.write(sys.argv[0] + " - Require arguments (filename)\n")
        return 1

    def switch_value(pos):
        # The argument following the switch at ``pos``; exit(1) if missing.
        if pos + 1 >= len(sys.argv):
            sys.stderr.write("Need argument for " + sys.argv[pos] + " switch\n")
            sys.exit(1)
        return sys.argv[pos + 1]

    plotFunc = ShowTCS
    normalise = False
    title = ""
    master_title = ""

    pos = 1
    while pos < len(sys.argv):
        arg = sys.argv[pos]
        if arg == "--raw":
            plotFunc = ShowData
        elif arg == "--tcs":
            plotFunc = ShowTCS
        elif arg == "--output":
            target = switch_value(pos)
            gnuplot("set term postscript colour")
            gnuplot("set output \"" + target + "\"")
            pos += 1
        elif arg == "--wxt":
            gnuplot("set term wxt")
        elif arg in ("--normalise", "--unnormalise"):
            normalise = (arg == "--normalise")
        elif arg == "--title":
            title = switch_value(pos)
            pos += 1
        elif arg == "--master_title":
            master_title = switch_value(pos)
            pos += 1
        else:
            plotFunc(arg, plot=gnuplot.replot, normalise=normalise, title=title, master_title=master_title)
        pos += 1

    print("Done. Press enter to exit, or type name of file to save as.")
    reply = sys.stdin.readline().strip("\t\r\n #")
    if reply != "":
        gnuplot("set term postscript colour")
        gnuplot("set output \"" + reply + "\"")
        gnuplot.replot()
574
575
def ModelTCS(f, sigma, Emin, Emax, dE):
    """Tabulate a model spectrum S(E) for E = Emin, Emin + dE, ... while E < Emax.

    For each energy E:

        S(E) = (1 - sigma(0)) * f(-E)
               + integral over e in [Emin, Emax] of f(e - E) * sigma'(e) de

    The integral is computed with FuncIntegrate (step dE), and sigma' with
    FuncDerivative(sigma, e, dE).  *f* and *sigma* are one-argument
    callables.

    Returns rows [0.0, E, S(E), 0.0].  Raises ValueError if dE <= 0, which
    would otherwise loop forever.
    """
    if dE <= 0:
        raise ValueError("ModelTCS: dE must be positive")
    data = []
    E = Emin
    while E < Emax:
        # Differentiate sigma at the integration variable e.  It used to be
        # evaluated at the fixed outer energy E, which made it a constant
        # factor rather than part of the integrand.
        S = (1 - sigma(0)) * f(-E) + FuncIntegrate(
            lambda e: f(e - E) * FuncDerivative(sigma, e, dE), Emin, Emax, dE)
        data.append([0.00, E, S, 0.00])
        E += dE
    return data
584
def IntegrateTCS(data, imin, imax=0, di=1):
    """Integrate column 2 of *data* over column 1 with a left Riemann sum.

    Rows imin, imin + di, ... strictly below imax are used.  Each contributes
    data[i][2] times the column-1 distance to the next sample used, i.e.
    row i + di clipped to imax.  imax == 0 means "up to the last row".

    Previously each term used the spacing to row i + 1 even when di > 1, so
    the result was too small by roughly a factor of di.  Raises ValueError
    if di <= 0, which would otherwise never terminate.
    """
    if di <= 0:
        raise ValueError("IntegrateTCS: di must be positive")
    if imax == 0:
        imax = len(data) - 1
    total = 0.0
    i = imin
    while i < imax:
        nxt = min(i + di, imax)
        total += data[i][2] * (data[nxt][1] - data[i][1])
        i += di
    return total
595
def FuncIntegrate(f, xmin, xmax, dx):
    """Approximate the integral of *f* from *xmin* to *xmax* with a rectangle sum.

    Samples f at xmin, xmin + dx, ... while x <= xmax and returns
    sum(f(x) * dx).  Because x is advanced by repeated addition, whether
    xmax itself is sampled depends on floating-point rounding.

    Raises ValueError if dx <= 0, which would otherwise loop forever.
    """
    if dx <= 0:
        raise ValueError("FuncIntegrate: dx must be positive")
    x = xmin
    total = 0.0
    while x <= xmax:
        total += f(x) * dx
        x += dx
    return total
603
def FuncDerivative(f, x, dx):
    """Central-difference estimate of f'(x), sampling f at x + dx and then x - dx."""
    rise = f(x + dx) - f(x - dx)
    return 0.50 * rise / dx
606
def FitTCS(data, min_mse=1e-4, max_fail=100, max_adjust=4, divide=10, plot=gnuplot.plot, smooth=0):
    """Fit a sum of Gaussians to a TCS curve by random single-parameter moves.

    Parameters:
        data       -- file name (loaded, calibrated, max-normalised and
                      differentiated here) or an already-prepared list of rows.
        min_mse    -- stop once the mean square error is at or below this.
        max_fail   -- stop after this many consecutive rejected moves.
        max_adjust -- the move size (initially 1.0) may be halved at most
                      this many times.
        divide     -- compare model and data at every divide-th data row.
        plot       -- Gnuplot plotting function, or None to return the plot
                      items instead of drawing them.
        smooth     -- Smooth window, or [passes, window], applied first.

    Seeding: one Gaussian [amplitude, centre, width] is placed at each peak
    track from SmoothPeakFind.  Its amplitude and centre come from the
    track's last entry; its width is half the distance to the nearest
    neighbouring centre.  Random perturbations of one parameter at a time
    are kept when they lower the mean square error of the model over 0..16.8.

    Returns [fits, model] when plotting, else [fits, model, plotItems].
    fits is sorted by amplitude * gaussian(0, width), largest first.
    Raises ValueError if no peaks are found to fit.
    """
    if type(data) == type(""):
        d = GetData(data)
        d = CalibrateData(d)
        d = MaxNormalise(d)
        d = Derivative(d)
    else:
        d = data

    if smooth != 0:
        if type(smooth) == type([]):
            for _ in range(smooth[0]):
                d = Smooth(d, m=smooth[1])
        else:
            d = Smooth(d, m=smooth)

    plotItems = ShowTCS(d, raw=False, smooth=smooth, plot=None)
    plotItems.append(None)  # slot for the model curve, filled in below

    peaks = SmoothPeakFind(d, smooth=5, stop=1, inflection=0)
    # Track entries are [pass] + row + [inflection], so index 2 holds the
    # value used as the Gaussian centre below.  The neighbour-based widths
    # need the tracks in centre order; the sort key used to be index 1.
    peaks.sort(key=lambda e: e[len(e) - 1][2])

    fits = []
    for i in range(len(peaks)):
        l = peaks[i][-1]
        if l[-1] == 0:
            fits.append([l[3], l[2], 1.0])
        else:
            # For a non-zero inflection level, seed from the tracks two
            # places either side.  These lookups used index len(track),
            # one past the end, which always raised IndexError.
            if i - 2 >= 0:
                l = peaks[i - 2][-1]
                if l[-1] == 0:
                    fits.append([l[3], l[2], 1.0])
            if i + 2 <= len(peaks) - 1:
                l = peaks[i + 2][-1]
                if l[-1] == 0:
                    fits.append([l[3], l[2], 1.0])

    if len(fits) == 0:
        raise ValueError("FitTCS: no peaks found to fit")

    # Initial width: half the distance to the nearest neighbouring centre.
    # A missing neighbour at either end counts as 2.0 away.
    for i in range(len(fits)):
        left = 2.0
        right = 2.0
        if i > 0:
            left = fits[i - 1][1] - fits[i][1]
        if i < len(fits) - 1:
            right = fits[i + 1][1] - fits[i][1]
        fits[i][2] = min([abs(0.5 * left), abs(0.5 * right)])

    def tcs(E):
        # Model value at E: sum of the current Gaussians.
        total = 0.0
        for f in fits:
            total += f[0] * gaussian(E - f[1], f[2])
        return total

    model_step = divide * d[len(d) - 1][1] / (len(d))
    reference = d[0::divide]

    def evaluate():
        # Tabulate the model over 0..16.8 and score it against the data.
        model = table(lambda e: [0.00, e, tcs(e), 0.00], 0, 16.8, model_step)
        return model, MeanSquareError(model, reference)

    # Score the starting guess so the first move is judged against it.  The
    # error used to start at an arbitrary 1, and ``model`` was undefined if
    # the loop never ran.
    model, mse = evaluate()
    failcount = 0
    adjust = 1.0

    while failcount < max_fail and mse > min_mse:
        i = random.randint(0, len(fits) - 1)
        j = random.randint(0, len(fits[i]) - 1)

        old = fits[i][j]
        old_mse = mse
        old_model = model

        fits[i][j] += adjust * (random.random() - 0.50)
        if j == 2:
            # Keep the Gaussian width positive.  This test used the fit
            # index i instead of the parameter index j.
            while fits[i][j] <= 0.0005:
                fits[i][j] = adjust * (random.random() - 0.50)

        model, mse = evaluate()
        if mse >= old_mse:
            # Reject the move, and restore the matching model so that the
            # returned model always corresponds to the returned fits.
            fits[i][j] = old
            model = old_model
            mse = old_mse
            failcount += 1
            if failcount > max_fail / 2:
                if adjust > 1.0 / (2.0 ** max_adjust):
                    adjust /= 2
        else:
            failcount = 0

    plotItems[len(plotItems) - 1] = Gnuplot.Data(model, using="2:3", with_="l lt 3", title="model")
    time.sleep(0.1)

    fits.sort(key=lambda e: e[0] * gaussian(0, e[2]), reverse=True)

    if plot is not None:
        gnuplot("set title \"MSE = " + str(mse) + "\\nfailcount = " + str(failcount) + "\\nadjust = " + str(adjust) + "\"")
        # Honour the caller's plot function (gnuplot.plot was always used).
        plot(*plotItems)
        return [fits, model]
    else:
        return [fits, model, plotItems]
734
def SaveFit(filename, fit):
    """Write a fit to *filename* in the format LoadFit reads.

    The file starts with the line "# TCS Fit", followed by one tab-separated
    line per entry containing that entry's first three values.  The file is
    closed even on error, and the unbuffered text-mode ``open(..., "w", 0)``
    (rejected by Python 3) is no longer used.
    """
    with open(filename, "w") as out:
        out.write("# TCS Fit\n")
        for f in fit:
            out.write(str(f[0]) + "\t" + str(f[1]) + "\t" + str(f[2]) + "\n")
743
744
def LoadFit(filename):
    """Read a fit written by SaveFit.

    The first line must be exactly "# TCS Fit".  Otherwise an error is
    written to stderr and the program exits with status 1; it used to exit
    with status 0, which signals success.  After that, lines of three
    tab-separated numbers are read until the first line that is not one.

    Returns a list of [amplitude, centre, width] floats, sorted by
    amplitude * gaussian(0, width), largest first.
    """
    with open(filename, "r") as infile:
        if infile.readline() != "# TCS Fit\n":
            sys.stderr.write("Error loading fit from file " + str(filename) + "\n")
            sys.exit(1)

        fit = []
        while True:
            f = infile.readline().strip("# \r\n\t").split("\t")
            if len(f) != 3:
                break
            fit.append([float(v) for v in f])

    fit.sort(key=lambda e: e[0] * gaussian(0, e[2]), reverse=True)
    return fit
768
def MeanSquareError(model, real, k=2):
    """Mean square difference between column *k* of *model* and of *real*.

    Rows are paired by index for every row of *real*, so *model* must have
    at least as many rows.  The sum is divided by the number of pairs,
    len(real).  It used to be divided by len(model), which understated the
    error whenever the model had more rows than the data.  An empty *real*
    gives 0.0.
    """
    if len(real) == 0:
        return 0.0
    total = 0.0
    for i in range(len(real)):
        total += (model[i][k] - real[i][k]) ** 2
    return total / len(real)
777
def delta(x):
    """Kronecker delta: 1.0 when x == 0, otherwise 0.0."""
    return 1.0 if x == 0 else 0.0
783
def table(f, xmin, xmax, dx):
    """Return [f(xmin), f(xmin + dx), ...] for every sampled x <= xmax.

    x is advanced by repeated addition, so whether xmax itself is included
    depends on floating-point rounding.  Raises ValueError if dx <= 0,
    which would otherwise loop forever.
    """
    if dx <= 0:
        raise ValueError("table: dx must be positive")
    result = []
    x = xmin
    while x <= xmax:
        result.append(f(x))
        x += dx
    return result
791
def gaussian(x, sigma):
    """Normalised Gaussian exp(-x^2 / (2 sigma^2)) / (sigma * sqrt(2 pi)).

    Returns 0.0 when sigma is exactly zero.  NOTE: a negative sigma yields
    a negative value, because sigma is not passed through abs().
    """
    if sigma != 0.0:
        exponent = -(x ** 2.0) / (2.0 * sigma ** 2.0)
        norm = sigma * (2.0 * math.pi) ** 0.50
        return math.exp(exponent) / norm
    return 0.0
796
def step(x, sigma, T):
    """Smooth step 1 / (exp((x - sigma) / T) + 1), equal to 0.5 at x == sigma.

    For T > 0 it falls from 1 to 0 as x increases past sigma; *T* sets the
    width of the transition.  T == 0 returns 1.0.  NOTE(review): the sharp
    limit would be 0 for x > sigma, so confirm this is intended.

    When (x - sigma) / T is so large that exp() overflows, the limiting
    value 0.0 is returned instead of raising OverflowError.  T is converted
    to float so integer arguments do not truncate under Python 2.
    """
    if T == 0:
        return 1.0
    try:
        return 1.0 / (math.exp((x - sigma) / float(T)) + 1.0)
    except OverflowError:
        return 0.0
801
if __name__ == "__main__":
    # Use main()'s return value (1 when no arguments are given) as the exit status.
    exit_status = main()
    sys.exit(exit_status)

UCC git Repository :: git.ucc.asn.au