ARGH
[matches/honours.git] / thesis / figures / tcs / plots / process.py~
1 #!/usr/bin/python -u
2
3 #
4 # @file process.py
5 # @purpose Process TCS data
6 #               Takes S(E) = dI/dE
7 # @author Sam Moore
8 # @date August 2012
9 #
10
11 import sys
12 import os
13 import re # Regular expressions - for removing comments
14 import odict #ordered dictionary
15 import copy
16
17 import Gnuplot, Gnuplot.funcutils
18 import string
19 import time
20 import math
21 import cmath
22 import random
23 import numpy
24
# Module-wide gnuplot session shared by the plotting helpers in this file.
gnuplot = Gnuplot.Gnuplot()


def Reset():
    """Replace the module-wide ``gnuplot`` session with a new one.

    Without a ``global`` declaration the assignment would only create a
    local variable and the module-level session would never be replaced.
    Functions whose default arguments were bound to the old session when
    they were defined (e.g. ``plot=gnuplot.plot`` in ``ShowTCS``) keep
    referring to that old session.
    """
    global gnuplot
    gnuplot = Gnuplot.Gnuplot()
29
def FindDataFiles(directory=".", depth=1, result=None):
    """Collect the paths of ``.dat`` files under *directory*.

    Only names with exactly one dot and a ``dat`` extension match
    (``a.dat`` but not ``a.b.dat``).  Subdirectories are searched while
    ``depth > 1``, one level per unit of depth.

    :param directory: directory to scan; paths are built as
        ``directory + "/" + name``.
    :param depth: number of directory levels to scan (1 = *directory* only).
    :param result: optional list that matches are appended to.
    :returns: the list of matching paths (*result* itself when given).
    """
    if result is None:
        result = []

    for name in os.listdir(directory):
        path = directory + "/" + str(name)
        if os.path.isdir(path):
            if depth > 1:
                # The recursive call appends into the same list.  Doing
                # ``result += <returned list>`` would extend the list with
                # itself and duplicate every entry found so far.
                FindDataFiles(path, depth - 1, result)
            continue
        parts = name.split(".")
        if len(parts) == 2 and parts[1] == "dat":
            result.append(path)

    return result
44
def BaseName(f):
    """Return the last ``/``-separated component of path *f*.

    A trailing slash yields the empty string; a path without any slash is
    returned unchanged.
    """
    return f.rpartition("/")[2]
48         
49
def DirectoryName(f, start=0, back=1):
    """Join the ``/``-separated parts of *f* from index *start* up to *back* from the end.

    With the defaults this drops the final component, i.e. it returns the
    parent directory of *f*.  Uses ``str.join``; ``string.join`` no longer
    exists in Python 3.
    """
    parts = f.split("/")
    return "/".join(parts[start:len(parts) - back])
53
def GetData(filename, key=1):
    """Load a tab-separated data file, averaging rows that share a key.

    :param filename: a file path, a directory path, or already-loaded data.
        A list is returned unchanged; any other non-string gives
        ``[[0, 0, 0, 0]]``.  For a directory, any existing ``average.dat``
        in it is removed, ``AverageAllData`` regenerates it, and that file
        is loaded.
    :param key: column used to group rows; rows with equal values in this
        column are averaged column by column.
    :returns: list of rows (lists of floats) sorted by column *key*.

    Text after ``#`` on a line is treated as a comment and blank lines are
    skipped.
    """
    if not isinstance(filename, str):
        if isinstance(filename, list):
            return filename
        return [[0, 0, 0, 0]]

    if os.path.isdir(filename):
        # rstrip, not strip: stripping a leading "/" would turn an absolute
        # directory path into a relative one.
        average = filename.rstrip("/") + "/average.dat"
        if os.path.exists(average):
            os.remove(average)
        AverageAllData(filename)
        return GetData(average, key)

    sums = {}  # key value -> [column sums, number of rows]
    with open(filename, "r") as input_file:
        for line in input_file:
            line = re.sub("#.*", "", line).strip("\r\n ")
            if not line:
                continue
            row = [float(e) for e in line.split("\t")]
            entry = sums.get(row[key])
            if entry is None:
                sums[row[key]] = [row, 1]
            else:
                for i, value in enumerate(row):
                    entry[0][i] += value
                entry[1] += 1

    d = [[float(v) / float(count) for v in total] for total, count in sums.values()]
    d.sort(key=lambda e: e[key])
    return d
90
def DoNothing(data):
    """Identity transform: return *data* itself (no copy is made)."""
    unchanged = data
    return unchanged
93
94
def GetDataSets(directory="."):
    """Load every ``.dat`` file directly inside *directory*.

    A non-directory entry qualifies when the text between its first and
    second dot is ``dat`` (so both ``run.dat`` and ``run.dat.bak`` match).
    Each file is loaded with ``GetData``; empty results are dropped.

    :returns: list of loaded data sets, in ``os.listdir`` order.
    """
    loaded = []
    for name in os.listdir(directory):
        path = directory + "/" + str(name)
        if os.path.isdir(path):
            continue
        pieces = name.split(".")
        if len(pieces) <= 1 or pieces[1] != "dat":
            continue
        data_set = GetData(path)
        if len(data_set) > 0:
            loaded.append(data_set)
    return loaded
104
105
106
def Derivative(data, a=1, b=2, sigma=None, step=1):
    """Numerically differentiate column *b* with respect to column *a*.

    Every *step*-th row of *data* is copied.  In the copy, column *b* is
    replaced by the mean difference in *b* divided by the mean difference
    in *a*, taken over the backward and forward neighbours *step* rows away
    (only the one that exists is used at either end).

    :param sigma: optional column index; its sign only selects the
        direction.  When given, half of that column's value in the current
        row is subtracted from (``sigma < 0``) or added to (otherwise) each
        difference in column *b* before averaging.
    :returns: list of the copied rows.  A row with no neighbours, or whose
        mean difference in column *a* is zero, gets 0.0 in column *b*.
    """
    derived = []
    last_with_forward = len(data) - step
    for i in range(0, len(data), step):
        here = data[i]
        row = list(here)

        # Each neighbour that exists contributes a [delta_a, delta_b] pair.
        backward = None
        if i >= step:
            before = data[i - step]
            backward = [here[a] - before[a], here[b] - before[b]]
        forward = None
        if i < last_with_forward:
            after = data[i + step]
            forward = [after[a] - here[a], after[b] - here[b]]
        neighbours = [pair for pair in (backward, forward) if pair is not None]

        if sigma is not None and neighbours:
            half = 0.5 * here[int(abs(sigma))]
            for pair in neighbours:
                if sigma < 0:
                    pair[1] -= half
                else:
                    pair[1] += half

        if neighbours:
            sum_a = 0.0
            sum_b = 0.0
            for da, db in neighbours:
                sum_a += da
                sum_b += db
            sum_b /= float(len(neighbours))
            sum_a /= float(len(neighbours))
            row[b] = sum_b / sum_a if sum_a != 0 else 0.0
        else:
            row[b] = 0.0
        derived.append(row)
    return derived
163
def MaxNormalise(data, u=2):
    """Return a deep copy of *data* with column *u* divided by its maximum.

    The maximum is the largest signed value in column *u*.  Empty input,
    or a maximum of exactly zero, returns an unscaled copy.  The division
    is done in floating point, so integer columns are not truncated by
    Python 2 integer division.  *data* itself is not modified.
    """
    result = copy.deepcopy(data)
    if len(data) <= 0:
        return result
    maxval = max(row[u] for row in data)
    if maxval == 0:
        return result
    scale = float(maxval)
    for row in result:
        row[u] = row[u] / scale
    return result
177         
def Average(data_sets, u=1):
    """Average rows across several data sets, grouping on column *u*.

    Rows from any data set that share the same value in column *u* are
    averaged column by column.

    :param data_sets: iterable of data sets, each a list of numeric rows.
    :param u: index of the grouping column.
    :returns: list of averaged rows, sorted by the grouping value.

    The input rows are never modified: each group accumulates into a copy
    of its first row rather than into the caller's row.
    """
    groups = {}  # grouping value -> [column sums, row count]
    for data_set in data_sets:
        for row in data_set:
            entry = groups.get(row[u])
            if entry is None:
                groups[row[u]] = [list(row), 1]
            else:
                entry[1] += 1
                for i in range(len(row)):
                    entry[0][i] += row[i]

    return [[value / float(count) for value in sums]
            for _, (sums, count) in sorted(groups.items(), key=lambda e: e[0])]
195
def FullWidthAtHalfMax(data, u=1, x=0):
    """Estimate the full width at half maximum of the peak in column *u*.

    The peak is the first row with the largest column-*u* value.  Scanning
    outward from it, the first row on each side whose column-*u* value is
    below half the peak value marks that edge.  The width is the distance
    between the two edge rows measured in column *x*.  If either side never
    drops below half maximum, the full span ``abs(data[-1][x] - data[0][x])``
    is returned instead.

    :param data: non-empty list of rows.
    :param u: column holding the values that form the peak.
    :param x: column the width is measured in (defaults to 0, the column
        used for the full-span fallback).
    :raises ValueError: if *data* is empty.

    Comparisons are made on column *u* values (comparing whole rows with a
    float is always false in Python 2 and a TypeError in Python 3), and
    both edges are expressed in the same units.
    """
    if not data:
        raise ValueError("FullWidthAtHalfMax() requires non-empty data")
    peak = max(range(len(data)), key=lambda i: data[i][u])
    half = 0.50 * data[peak][u]

    lhs = None
    for i in range(peak - 1, -1, -1):
        if data[i][u] < half:
            lhs = data[i][x]
            break

    rhs = None
    for i in range(peak + 1, len(data)):
        if data[i][u] < half:
            rhs = data[i][x]
            break

    if lhs is None or rhs is None:
        return abs(data[-1][x] - data[0][x])
    return abs(rhs - lhs)
215
def SaveData(filename, data):
    """Write *data* to *filename* as tab-separated text, one row per line.

    Each value is written with ``str()``, and an existing file is
    overwritten.  The ``with`` block guarantees the file is closed.
    Unbuffered text mode (``open(..., "w", 0)``) is rejected by Python 3,
    so the default buffering is used.
    """
    with open(filename, "w") as out:
        for row in data:
            out.write("\t".join(str(value) for value in row))
            out.write("\n")
224
def AverageAllData(directory, save=None, normalise=True):
    """Average every ``.dat`` file in *directory* and save the result.

    Each file found by ``FindDataFiles`` is loaded with ``GetData``,
    optionally scaled with ``MaxNormalise``, and the data sets are combined
    with ``Average``.  The result is written to *save* with ``SaveData``.

    :param save: output path; defaults to ``directory + "/average.dat"``.
        The output file itself is skipped as an input, so re-running does
        not fold a previous average back into the new one.
    :param normalise: max-normalise each data set before averaging.
    :returns: the averaged rows.
    """
    if save is None:
        save = directory + "/average.dat"
    save_path = os.path.abspath(save)

    data_sets = []
    for f in FindDataFiles(directory):
        if os.path.abspath(f) == save_path:
            continue
        d = GetData(f)
        if normalise:
            d = MaxNormalise(d)
        data_sets.append(d)

    a = Average(data_sets)
    SaveData(save, a)
    return a
237
def CalibrateData(original, ammeter_scale=1e-6):
    """Convert raw counts to physical units on a deep copy of *original*.

    Column 1 becomes ``16.8 * counts / 4000``.  Columns 2 and 3 become
    ``ammeter_scale * 0.170 * counts / 268``.  All other columns are copied
    unchanged, and *original* is not modified.
    """
    converted = copy.deepcopy(original)
    current_factor = ammeter_scale * 0.170
    for row in converted:
        row[1] = 16.8 * float(row[1]) / 4000.0
        for col in (2, 3):
            row[col] = current_factor * float(row[col]) / 268.0
    return converted
245
def ShowTCS(filename, raw=True, calibrate=True, normalise=False, show_error=False, plot=gnuplot.plot, with_="lp", step=1, output=None, title="", master_title="", smooth=0, show_peak=False, inflection=1):
    """Plot the total current spectrum S(U) with gnuplot.

    :param filename: data file path (loaded with ``GetData``) or
        already-loaded rows.
    :param raw: if true, the data is differentiated here with
        ``Derivative``.  If false, it is plotted as given, and calibration
        and normalisation are skipped.
    :param smooth: 0 for none, an int window half-width, or
        ``[passes, width]``.
    :param plot: plotting callable (e.g. ``gnuplot.plot`` or
        ``gnuplot.replot``).  ``None`` returns the ``Gnuplot.Data`` items
        instead of plotting them.
    :param output: optional PNG path; the terminal is set back to ``wxt``
        afterwards.
    :param show_error: also plot derivatives with half of column 3
        subtracted from / added to the differences.
    :param show_peak: overlay the peaks found by ``SmoothPeakFind``.
    :returns: the plot items when *plot* is ``None``, otherwise the
        processed data.
    """
    if not raw:
        calibrate = False
        normalise = False

    if isinstance(filename, str):
        data = GetData(filename)
    else:
        data = filename
        filename = "tcs data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Check for the [passes, width] form first.  Comparing a list with 0 is
    # meaningless in Python 2 and a TypeError in Python 3.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["Volts", "uA / V"]
    else:
        units = ["DAC counts", "ADC counts / DAC counts"]

    if not normalise:
        gnuplot("set ylabel \"S(U) (" + str(units[1]) + ")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"S(U) (arb. units)\"")

    to_file = output is not None and isinstance(output, str)
    if to_file:
        gnuplot("set term png size 640,480")
        gnuplot("set output \"" + str(output) + "\"")

    if master_title == "":
        master_title = "Total Current Spectrum S(U)"
        if isinstance(filename, str) and plot == gnuplot.plot and filename != "tcs data":
            p = ReadParameters(filename)
            if "Sample" in p:
                master_title += "\\nSample: " + p["Sample"]

    gnuplot("set title \"" + str(master_title) + "\"")
    gnuplot("set xlabel \"U (" + str(units[0]) + ")\"")
    gnuplot("set xrange [0:16]")
    gnuplot("set mxtics 10")
    gnuplot("set mytics 10")

    d = Derivative(data, 1, 2, step=step) if raw else data

    ymax = 0.01 + 1.2 * max(row[2] for row in d)
    ymin = -0.01 + 1.2 * min(row[2] for row in d)
    gnuplot("set yrange [" + str(ymin) + ":" + str(ymax) + "]")

    plotList = [Gnuplot.Data(d, using="2:3", with_=with_, title=title)]

    if show_error:
        # Error curves use the caller's line style (with_).
        error1 = Derivative(data, 1, 2, -3, step=step)
        error2 = Derivative(data, 1, 2, +3, step=step)
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_, title="-sigma/2"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="+sigma/2"))

    if show_peak:
        peak = SmoothPeakFind(d, ap=DoNothing, stop=1, inflection=inflection)
        plotList += PlotPeaks(peak, with_="l lt -1", plot=None)

    if plot is not None:
        plot(*plotList)
        time.sleep(0.2)

    if to_file:
        gnuplot("set term wxt")

    if plot is None:
        return plotList
    return data
336
def ShowData(filename, calibrate=True, normalise=False, show_error=False, plot=gnuplot.plot, with_="lp", step=1, output=None, title="", master_title="Sample Current I(E)", smooth=0):
    """Plot the raw sample current I(U) with gnuplot.

    :param filename: data file path (loaded with ``GetData``) or
        already-loaded rows.
    :param smooth: 0 for none, an int window half-width, or
        ``[passes, width]``.
    :param plot: plotting callable, or ``None`` to return the
        ``Gnuplot.Data`` items without plotting.
    :param output: optional PNG path.  The terminal is set back to ``wxt``
        only after an actual plot.
    :param show_error: also plot bounds at column 2 +/- half of column 3.
    :param step: unused; kept for signature compatibility with ``ShowTCS``.
    :returns: the processed data after plotting, or the plot items when
        *plot* is ``None``.
    """
    if isinstance(filename, str):
        data = GetData(filename)
    else:
        data = filename
        filename = "raw data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Dispatch on the type of *smooth*.  Testing *data* (always a list)
    # would index into an int *smooth* and raise a TypeError.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA"]
    else:
        units = ["DAC counts", "ADC counts"]

    if not normalise:
        gnuplot("set ylabel \"I(U) (" + str(units[1]) + ")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"I(U) (arb. units)\"")

    to_file = output is not None and isinstance(output, str)
    if to_file:
        gnuplot("set term png size 640,480")
        gnuplot("set output \"" + str(output) + "\"")

    gnuplot("set title \"" + str(master_title) + "\"")
    gnuplot("set xlabel \"U (" + str(units[0]) + ")\"")

    ymax = 0.005 + 1.2 * max(row[2] for row in data)
    ymin = -0.005 + 1.2 * min(row[2] for row in data)
    gnuplot("set yrange [" + str(ymin) + ":" + str(ymax) + "]")

    plotList = [Gnuplot.Data(data, using="2:3", with_=with_, title=title)]
    time.sleep(0.1)
    if show_error:
        error1 = copy.deepcopy(data)
        error2 = copy.deepcopy(data)
        for i, row in enumerate(data):
            error1[i][2] -= 0.50 * float(row[3])
            error2[i][2] += 0.50 * float(row[3])
        # Bounds use the caller's line style (with_).
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_, title="Error : Low bound"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="Error : Upper bound"))

    if plot is None:
        return plotList
    plot(*plotList)
    if to_file:
        gnuplot("set term wxt")
    return data
407
def ReadParameters(filename):
    """Read ``key = value`` parameters from the lines of *filename*.

    Every line containing ``=`` contributes one entry.  The text before the
    first ``=`` is the key and everything after it is the value, so values
    may themselves contain ``=``.  Both parts are stripped of ``#``,
    spaces and line endings.  A repeated key overwrites the earlier value.

    :returns: an ``odict.odict`` (the file's ordered dictionary) mapping
        key strings to value strings.
    """
    parameters = odict.odict()
    with open(filename, "r") as input_file:
        for line in input_file:
            item, sep, value = line.partition("=")
            if not sep:
                continue
            item = item.strip("# \r\n")
            value = value.strip("# \r\n")
            if item in parameters:
                parameters[item] = value
            else:
                parameters.update({str(item): value})
    return parameters
424
def PlotParameters(filename):
    """Parse the parameters in *filename*; plotting is not implemented yet.

    Calls ``ReadParameters`` and discards its result.  Always returns ``None``.
    """
    ReadParameters(filename)
    return None
427
def Smooth(data, m, k=2):
    """Return a copy of *data* with column *k* replaced by a centred moving average.

    Each row's column-*k* value becomes the mean of rows ``i - m`` through
    ``i + m`` inclusive, clipped to the ends of *data*.  The window is
    symmetric, so the output is not shifted half a sample towards lower
    indices.  ``m == 0`` leaves the values unchanged (returned as floats),
    and a negative *m* copies them unchanged.  *data* is not modified.
    """
    smooth = copy.deepcopy(data)
    n = len(data)
    for i in range(n):
        lo = max(0, i - m)
        hi = min(n, i + m + 1)
        if hi > lo:
            total = 0.0
            for j in range(lo, hi):
                total += data[j][k]
            smooth[i][k] = total / float(hi - lo)
        else:
            smooth[i][k] = data[i][k]
    return smooth
443
def PeakFind(data, k=2, threshold=0.00, inflection=0):
    """Find local extrema of column *k*, optionally also in its derivatives.

    An interior row is reported when both neighbours lie on the same side
    of it (both higher or both lower) and each neighbour differs from it by
    at least ``threshold * abs(value)``.  Each reported row is
    ``data[i] + [inflection]``, i.e. tagged with the current level.

    When ``inflection > 0`` the search also recurses on the derivative of
    column *k* (via ``Derivative``) with ``inflection - 1``, and appends
    those results.

    :returns: list of tagged rows.
    """
    results = []
    for i in range(1, len(data) - 1):
        here = data[i][k]
        left = data[i - 1][k] - here
        right = data[i + 1][k] - here
        limit = threshold * abs(here)
        if abs(left) < limit or abs(right) < limit:
            continue
        if left * right > 0:
            results.append(data[i] + [inflection])

    if inflection > 0:
        # Differentiate the column being searched, not always column 2.
        results += PeakFind(Derivative(data, b=k), k=k, threshold=threshold, inflection=inflection - 1)

    return results
465
def SmoothPeakFind(data, a=1, k=2, ap=DoNothing, stop=10, smooth=5, inflection=0):
    """Track the peaks of column *k* through successive rounds of smoothing.

    In each round, ``PeakFind`` runs on ``ap(current)``.  The data is then
    smoothed again with ``Smooth`` (half-width *smooth*) for the next round;
    rounds continue while the round number is below *stop*.  Every peak is
    stored as ``[round] + peak_row``.

    In round 0 each peak starts its own track.  Later, a peak joins the
    track whose latest entry is from this or the previous round and is
    closest in column *a*.  If no track qualifies, or the closest is more
    than 100 away, the peak starts a new track.

    :returns: list of tracks, each a list of ``[round] + peak_row`` entries.
    """
    tracks = []
    current = data
    level = 0
    while level < stop:
        for peak in PeakFind(ap(current), k=k, inflection=inflection):
            entry = [level] + list(peak)
            if level == 0:
                tracks.append([entry])
                continue

            candidates = []
            for index, track in enumerate(tracks):
                latest = track[-1]
                if level - latest[0] > 1:
                    continue
                candidates.append([index, abs(peak[a] - latest[1 + a])])
            candidates.sort(key=lambda c: c[1])

            if not candidates or candidates[0][1] > 100:
                tracks.append([entry])
            else:
                tracks[candidates[0][0]].append(entry)
        level += 1
        current = Smooth(current, m=smooth, k=k)
    return tracks
511         
512                         
513
def PlotPeaks(peaks, calibrate=True, with_="lp", plot=gnuplot.replot):
    """Build, and optionally draw, one gnuplot line per peak track.

    Each track (as produced by ``SmoothPeakFind``) is drawn from a copy
    extended by one extra entry: a copy of its last entry with element 0
    (the round number) incremented.  Data is plotted with ``using="3:1"``.
    A track whose last element is below 1 (the level tag appended by
    ``PeakFind``) is drawn with line type 9.

    The caller's *peaks* are not modified, and the line-type override
    applies only to the track that triggers it.

    :param calibrate: unused.
    :param plot: plotting callable, or ``None`` to only return the items.
    :returns: list of ``Gnuplot.Data`` items.
    """
    plotList = []
    for track in peaks:
        extended = list(track)
        extra = copy.deepcopy(track[-1])
        extra[0] += 1
        extended.append(extra)

        style = with_
        if extra[-1] < 1:
            style = with_.split(" lt")[0] + " lt 9"
        plotList.append(Gnuplot.Data(extended, using="3:1", with_=style))

    if plotList and plot is not None:
        plot(*plotList)
        time.sleep(0.2)
    return plotList
540                 
def main():
    """Command-line entry point: plot each file argument with the current options.

    Switches affect the files that follow them:
      --raw / --tcs              plot with ShowData / ShowTCS (default)
      --output FILE              send plots to a colour PostScript file
      --wxt                      use the wxt terminal
      --normalise / --unnormalise
      --title T, --master_title T
      --smooth N or PxN          half-width N, or P passes of half-width N
      --with STYLE               gnuplot line style
    Any other argument is plotted with ``gnuplot.replot``.  Afterwards one
    line is read from stdin; if it is non-empty, the final plot is saved
    under that name as PostScript.

    :returns: 1 when no arguments are given, otherwise None.
    """
    if len(sys.argv) < 2:
        sys.stderr.write(sys.argv[0] + " - Require arguments (filename)\n")
        return 1

    def need_value(i):
        # Abort when a switch that takes a value is the last argument.
        if i + 1 >= len(sys.argv):
            sys.stderr.write("Need argument for " + sys.argv[i] + " switch\n")
            sys.exit(1)
        return sys.argv[i + 1]

    plotFunc = ShowTCS
    normalise = True
    title = ""
    master_title = ""
    smooth = 0
    with_ = "lp"

    i = 1
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg == "--raw":
            plotFunc = ShowData
        elif arg == "--tcs":
            plotFunc = ShowTCS
        elif arg == "--output":
            target = need_value(i)
            gnuplot("set term postscript colour")
            gnuplot("set output \"" + target + "\"")
            i += 1
        elif arg == "--wxt":
            gnuplot("set term wxt")
        elif arg == "--normalise":
            normalise = True
        elif arg == "--unnormalise":
            normalise = False
        elif arg == "--title":
            title = need_value(i)
            i += 1
        elif arg == "--master_title":
            master_title = need_value(i)
            i += 1
        elif arg == "--smooth":
            # "N" -> N; "PxN" -> [P, N].  Build a real list so len() works
            # on Python 3 too.
            smooth = [int(s) for s in need_value(i).split("x")]
            if len(smooth) <= 1:
                smooth = smooth[0]
            i += 1
        elif arg == "--with":
            with_ = need_value(i)
            i += 1
        else:
            plotFunc(arg, plot=gnuplot.replot, normalise=normalise, title=title,
                     master_title=master_title, smooth=smooth, with_=with_)
        i += 1

    # sys.stdout.write works on both Python 2 and 3 (unlike the print statement).
    sys.stdout.write("Done. Press enter to exit, or type name of file to save as.\n")
    out = sys.stdin.readline().strip("\t\r\n #")
    if out != "":
        gnuplot("set term postscript colour enhanced \"Arial Bold,18\"")
        gnuplot("set output \"" + out + "\"")
        gnuplot.replot()
616
617
def ModelTCS(f, sigma, Emin, Emax, dE):
    """Tabulate a model total current spectrum from *f* and *sigma*.

    For each E from *Emin* (inclusive) towards *Emax* (exclusive) in steps
    of *dE*::

        S(E) = (1 - sigma(0)) * f(-E) + sum over e of f(e - E) * sigma'(e) * dE

    The sum runs over [Emin, Emax] via ``FuncIntegrate``, and sigma' is a
    central difference from ``FuncDerivative``.  The derivative is taken at
    the integration variable e; evaluating it at the fixed E would turn it
    into a constant factor outside the integral.

    :returns: rows ``[0.0, E, S, 0.0]``.
    """
    data = []
    E = Emin
    while E < Emax:
        S = (1 - sigma(0)) * f(-E) + FuncIntegrate(
            lambda e: f(e - E) * FuncDerivative(sigma, e, dE), Emin, Emax, dE)
        data.append([0.00, E, S, 0.00])
        E += dE
    return data
626
def IntegrateTCS(data, imin, imax=0, di=1):
    """Left Riemann sum of column 2 over column 1 for rows *imin* .. *imax*.

    Every *di*-th row from *imin* is sampled.  Each sample's width runs to
    the next sampled row, clipped at *imax*, so strided sums still cover
    the whole range.

    :param imax: last row index bounding the integral; 0 means the last row.
    :param di: row stride; must be positive when ``imin < imax``.
    :raises ValueError: if ``di <= 0`` would make the loop never terminate.
    :returns: the integral as a float.
    """
    if imax == 0:
        imax = len(data) - 1
    if di <= 0 and imin < imax:
        raise ValueError("IntegrateTCS() requires di > 0")
    total = 0.0
    i = imin
    while i < imax:
        nxt = min(i + di, imax)
        total += data[i][2] * (data[nxt][1] - data[i][1])
        i += di
    return total
637
def FuncIntegrate(f, xmin, xmax, dx):
    """Approximate the integral of *f* by summing ``f(x) * dx``.

    x starts at *xmin* and is advanced by repeated addition of *dx* while
    ``x <= xmax``.  Whether the point at *xmax* is included therefore
    depends on floating-point rounding.

    :raises ValueError: if ``dx <= 0`` while ``xmin <= xmax`` (the loop
        would never end).
    """
    if dx <= 0 and xmin <= xmax:
        raise ValueError("FuncIntegrate() requires dx > 0")
    x = xmin
    total = 0.0
    while x <= xmax:
        total += f(x) * dx
        x += dx
    return total
645
def FuncDerivative(f, x, dx):
    """Central-difference estimate of f'(x): (f(x+dx) - f(x-dx)) / (2*dx)."""
    rise = f(x + dx) - f(x - dx)
    return 0.50 * rise / dx
648
def FitTCS(data, min_mse=1e-4, max_fail=100, max_adjust=4, divide=10, plot=gnuplot.plot, smooth=0):
    """Fit a sum of Gaussians to a TCS curve by random hill-climbing.

    :param data: a file name (loaded, calibrated, max-normalised and
        differentiated here) or already-prepared rows.
    :param min_mse: stop once the mean square error drops to this value.
    :param max_fail: stop after this many consecutive rejected moves.
        After more than half of them the step size is halved, down to
        ``2 ** -max_adjust``.
    :param divide: the model is tabulated on a grid *divide* times coarser
        than the data and compared with every *divide*-th data row.
    :param plot: plotting callable used to draw data and model, or ``None``.
    :param smooth: 0, an int half-width, or ``[passes, width]`` for ``Smooth``.
    :returns: ``[fits, model]`` when plotting, else
        ``[fits, model, plotItems]``.  Each fit is
        ``[amplitude, centre, width]`` as used with ``gaussian``; fits are
        sorted by peak height, largest first.
    :raises ValueError: if no peaks are found to seed the fit.
    """
    if isinstance(data, str):
        d = GetData(data)
        d = CalibrateData(d)
        d = MaxNormalise(d)
        d = Derivative(d)
    else:
        d = data

    if smooth != 0:
        if isinstance(smooth, list):
            for _ in range(smooth[0]):
                d = Smooth(d, m=smooth[1])
        else:
            d = Smooth(d, m=smooth)

    # d is already smoothed; don't let ShowTCS smooth it a second time, or
    # the plotted data would differ from the data being fitted.
    plotItems = ShowTCS(d, raw=False, smooth=0, plot=None)
    plotItems.append(None)  # placeholder for the model curve

    peaks = SmoothPeakFind(d, smooth=5, stop=1, inflection=0)
    peaks.sort(key=lambda e: e[-1][1])

    # Seed one Gaussian per extremum: amplitude from entry[3] and centre
    # from entry[2] (entries are [round] + data row + [level]).
    fits = []
    for i, track in enumerate(peaks):
        last = track[-1]
        if last[-1] == 0:
            fits.append([last[3], last[2], 1.0])
        else:
            if i - 2 >= 0:
                last = peaks[i - 2][-1]
                if last[-1] == 0:
                    fits.append([last[3], last[2], 1.0])
            if i + 2 <= len(peaks) - 1:
                last = peaks[i + 2][-1]
                if last[-1] == 0:
                    fits.append([last[3], last[2], 1.0])

    if not fits:
        raise ValueError("FitTCS: no peaks found to fit")

    # Initial width: half the distance to the nearest neighbouring centre
    # (a missing neighbour counts as 2.0 away).
    for i in range(len(fits)):
        left = 2.0
        right = 2.0
        if i > 0:
            left = fits[i - 1][1] - fits[i][1]
        if i < len(fits) - 1:
            right = fits[i + 1][1] - fits[i][1]
        fits[i][2] = min(abs(0.5 * left), abs(0.5 * right))

    def tcs(E):
        # Sum of all Gaussians at energy E.
        total = 0.0
        for f in fits:
            total += f[0] * gaussian(E - f[1], f[2])
        return total

    grid_step = divide * d[len(d) - 1][1] / (len(d))
    samples = d[0::divide]

    def evaluate():
        # Tabulate the current model and score it against the samples.
        model = table(lambda e: [0.00, e, tcs(e), 0.00], 0, 16.8, grid_step)
        return model, MeanSquareError(model, samples)

    # Start from the true error of the initial guess, so the first move is
    # only kept if it actually improves the fit.
    model, mse = evaluate()
    failcount = 0
    adjust = 1.0

    while failcount < max_fail and mse > min_mse:
        i = random.randint(0, len(fits) - 1)
        j = random.randint(0, len(fits[i]) - 1)

        old = fits[i][j]
        old_mse = mse
        old_model = model

        fits[i][j] += adjust * (random.random() - 0.50)
        if j == 2:
            # Element 2 is the Gaussian width; keep it strictly positive.
            while fits[i][j] <= 0.0005:
                fits[i][j] = adjust * (random.random() - 0.50)

        model, mse = evaluate()
        if mse >= old_mse:
            # Reject the move and keep model/mse consistent with fits.
            fits[i][j] = old
            model = old_model
            mse = old_mse
            failcount += 1
            if failcount > max_fail / 2 and adjust > 1.0 / (2.0 ** max_adjust):
                adjust /= 2
        else:
            failcount = 0

    plotItems[-1] = Gnuplot.Data(model, using="2:3", with_="l lt 3", title="model")
    time.sleep(0.1)

    fits.sort(key=lambda e: e[0] * gaussian(0, e[2]), reverse=True)

    if plot is not None:
        gnuplot("set title \"MSE = " + str(mse) + "\\nfailcount = " + str(failcount) + "\\nadjust = " + str(adjust) + "\"")
        plot(*plotItems)
        return [fits, model]
    return [fits, model, plotItems]
776
def SaveFit(filename, fit):
    """Write fit parameters to *filename* in the format ``LoadFit`` reads.

    The first line is the ``# TCS Fit`` header.  Each following line holds
    the first three values of one fit entry, tab-separated.  The file is
    opened with ``with`` using default buffering, because unbuffered text
    mode is rejected by Python 3.
    """
    with open(filename, "w") as out:
        out.write("# TCS Fit\n")
        for f in fit:
            out.write(str(f[0]) + "\t" + str(f[1]) + "\t" + str(f[2]) + "\n")
785
786
def LoadFit(filename):
    """Read fit parameters written by ``SaveFit``.

    Lines after the ``# TCS Fit`` header are parsed as three tab-separated
    floats.  Reading stops at the first line that does not split into
    exactly three fields.  Fits are returned sorted by peak height
    (``amplitude * gaussian(0, width)``), largest first.

    If the header is missing, an error is written to stderr and the
    program exits with status 1, so the shell does not see success.
    """
    with open(filename, "r") as infile:
        if infile.readline() != "# TCS Fit\n":
            sys.stderr.write("Error loading fit from file " + str(filename) + "\n")
            sys.exit(1)
        fit = []
        for line in infile:
            fields = line.strip("# \r\n\t").split("\t")
            if len(fields) != 3:
                break
            fit.append([float(v) for v in fields])

    fit.sort(key=lambda e: e[0] * gaussian(0, e[2]), reverse=True)
    return fit
810
def MeanSquareError(model, real, k=2):
    """Mean of squared column-*k* differences between *model* and *real*.

    Rows are paired by position for every row of *real*, so *model* must
    have at least as many rows.  The sum is divided by the number of rows
    actually compared (``len(real)``), so a longer *model* does not
    understate the error.  Returns 0.0 when *real* is empty.
    """
    if len(real) == 0:
        return 0.0
    total = 0.0
    for i in range(len(real)):
        total += (model[i][k] - real[i][k]) ** 2
    return total / len(real)
819
def delta(x):
    """Discrete delta: 1.0 when *x* equals zero, otherwise 0.0."""
    return 1.0 if x == 0 else 0.0
825
def table(f, xmin, xmax, dx):
    """Return ``[f(x), ...]`` for x = xmin, xmin + dx, ... while ``x <= xmax``.

    x is advanced by repeated addition, so whether *xmax* itself is
    included depends on floating-point rounding.

    :raises ValueError: if ``dx <= 0`` while ``xmin <= xmax`` (the loop
        would never end).
    """
    if dx <= 0 and xmin <= xmax:
        raise ValueError("table() requires dx > 0")
    result = []
    x = xmin
    while x <= xmax:
        result.append(f(x))
        x += dx
    return result
833
def gaussian(x, sigma):
    """Zero-mean normal density with standard deviation *sigma*, evaluated at *x*.

    Returns 0.0 when ``sigma == 0`` instead of dividing by zero.  A negative
    *sigma* yields a negative value, because the normalisation uses *sigma*
    as given.
    """
    if sigma == 0.0:
        return 0.0
    exponent = -(x ** 2.0) / (2.0 * sigma ** 2.0)
    norm = sigma * (2.0 * math.pi) ** 0.50
    return math.exp(exponent) / norm
838
def step(x, sigma, T):
    """Smooth step ``1 / (exp((x - sigma) / T) + 1)``.

    The value falls from 1 to 0 around ``x = sigma``, with *T* setting the
    width.  ``T == 0`` returns 1.0.  When ``(x - sigma) / T`` is so large
    that ``math.exp`` overflows, the limiting value 0.0 is returned instead
    of raising OverflowError.
    """
    if T == 0:
        return 1.0
    try:
        return 1.0 / (math.exp((x - sigma) / T) + 1.0)
    except OverflowError:
        return 0.0
843
if __name__ == "__main__":
    # Use main()'s return value (1 on a usage error, None otherwise) as the
    # process exit status.
    status = main()
    sys.exit(status)

UCC git Repository :: git.ucc.asn.au