251223392c461f581245a6765d5b8bd2f9cf0c45
[matches/honours.git] / thesis / figures / tcs / plots / process.py~
1 #!/usr/bin/python -u
2
3 #
4 # @file process.py
5 # @purpose Process TCS data
6 #               Takes S(E) = dI/dE
7 # @author Sam Moore
8 # @date August 2012
9 #
10
11 import sys
12 import os
13 import re # Regular expressions - for removing comments
14 import odict #ordered dictionary
15 import copy
16
17 import Gnuplot, Gnuplot.funcutils
18 import string
19 import time
20 import math
21 import cmath
22 import random
23 import numpy
24
# Module-wide gnuplot session shared by the plotting functions in this file.
gnuplot = Gnuplot.Gnuplot()
26
def Reset():
    """Replace the shared module-level gnuplot session with a fresh one.

    The ``global`` declaration is required: without it the assignment only
    binds a local name and the module-level session is left untouched.

    Note that functions declaring ``plot=gnuplot.plot`` (or
    ``gnuplot.replot``) as a default argument captured the bound method of
    the session that existed when they were defined; those defaults are not
    affected by this call.
    """
    global gnuplot
    gnuplot = Gnuplot.Gnuplot()
29
def FindDataFiles(directory=".", depth=1, result=None):
    """Collect the paths of ``<name>.dat`` files found under *directory*.

    Parameters:
        directory -- directory to scan.
        depth     -- number of directory levels to scan; 1 scans only
                     *directory* itself, 2 also scans its sub-directories,
                     and so on.
        result    -- optional list to append to; a new list is created when
                     omitted.

    Returns the list of paths (``directory + "/" + name``).  Only names
    that split into exactly two parts on "." with "dat" as the second part
    are accepted.  The order follows ``os.listdir``.
    """
    if result is None:
        result = []

    for f in os.listdir(directory):
        path = directory + "/" + str(f)
        if os.path.isdir(path):
            if depth > 1:
                # The recursive call appends into `result` in place.  Adding
                # its return value (the very same list) to `result` as well
                # would duplicate every entry collected so far.
                FindDataFiles(path, depth - 1, result)
            continue
        s = f.split(".")
        if len(s) == 2 and s[1] == "dat":
            result.append(path)

    return result
44
def BaseName(f):
    """Return the part of the "/"-separated path *f* after the last "/".

    A path without any "/" is returned unchanged; a path ending in "/"
    yields the empty string.
    """
    return f.split("/")[-1]
48         
49
def DirectoryName(f, start=0,back=1):
    """Return a slice of the "/"-separated components of path *f*.

    Parameters:
        f     -- path string using "/" as separator.
        start -- index of the first component to keep.
        back  -- number of trailing components to drop (1 drops the file
                 name, giving the containing directory).

    Returns the kept components joined with "/".
    """
    a = f.split("/")
    # str.join replaces the deprecated string.join() module function,
    # which no longer exists in Python 3.
    return "/".join(a[start:(len(a)-back)])
53
def GetData(filename, key=1):
    """Load a tab-separated numeric data file into a list of rows.

    Parameters:
        filename -- path of a data file, path of a directory, or an already
                    loaded data list.
        key      -- column index used to merge duplicate rows and to sort
                    the result.

    Behaviour by argument type:
      * a list is returned unchanged;
      * any other non-string yields the placeholder ``[[0,0,0,0]]``;
      * a directory has its ``average.dat`` deleted, regenerated with
        AverageAllData() and then loaded;
      * a file is parsed line by line: ``#`` starts a comment, blank lines
        are skipped, fields are separated by tabs and converted to float.

    Rows sharing the same value in column *key* are averaged column by
    column.  The returned rows are sorted by column *key*.
    """
    if not isinstance(filename, str):
        if isinstance(filename, list):
            return filename
        return [[0, 0, 0, 0]]

    if os.path.isdir(filename):
        # Only strip trailing slashes: strip("/") would also remove the
        # leading "/" of an absolute path and point at the wrong location.
        average_file = filename.rstrip("/") + "/average.dat"
        if os.path.exists(average_file):
            os.remove(average_file)
        AverageAllData(filename)
        return GetData(average_file)

    # Maps key value -> [column-wise running sum, number of rows summed].
    data = {}
    with open(filename, "r") as input_file:
        for line in input_file:
            line = re.sub("#.*", "", line).strip("\r\n ")
            if len(line) == 0:
                continue

            row = [float(e) for e in line.split("\t")]

            if row[key] in data:
                entry = data[row[key]]
                for i in range(len(row)):
                    entry[0][i] += row[i]
                entry[1] += 1
            else:
                data[row[key]] = [row, 1]

    d = [[float(v) / float(count) for v in total]
         for (total, count) in data.values()]
    d.sort(key=lambda e: e[key])
    return d
90
def DoNothing(data):
        """Identity transform: return *data* unchanged.

        Used as the default ``ap`` callback of SmoothPeakFind().
        """
        return data
93
94
def GetDataSets(directory="."):
    """Load every data file found directly inside *directory*.

    A file qualifies when its name contains at least one "." and the text
    between the first and second "." is "dat".  Each qualifying file is
    loaded with GetData(); empty results are dropped.  Sub-directories are
    ignored.  Returns the list of loaded data sets in ``os.listdir`` order.
    """
    collected = []
    for entry in os.listdir(directory):
        path = directory + "/" + str(entry)
        if os.path.isdir(path):
            continue
        parts = entry.split(".")
        if not (len(parts) > 1 and parts[1] == "dat"):
            continue
        rows = GetData(path)
        if len(rows) > 0:
            collected.append(rows)
    return collected
104
105
106
def Derivative(data, a=1, b=2, sigma=None,step=1):
    """Numerically differentiate column *b* with respect to column *a*.

    Parameters:
        data  -- list of rows (lists of numbers).
        a     -- index of the independent-variable column.
        b     -- index of the dependent-variable column.
        sigma -- optional signed column index.  When given, half of the
                 value in column ``abs(sigma)`` is added to (sigma >= 0) or
                 subtracted from (sigma < 0) each difference of column *b*
                 before the slope is formed.
        step  -- stride between the samples used; only every step-th row
                 appears in the output.

    Returns a new list of row copies in which column *b* holds the slope.
    Interior samples use the backward and forward differences combined
    (sum of the b-differences over sum of the a-differences); the end
    samples use the single difference available.  The slope is 0.0 when no
    neighbour exists or when the a-difference is zero.
    """
    output = []
    last_forward = len(data) - step

    for i in range(0, len(data), step):
        row = list(data[i])

        # [dE, dI] pairs for the backward and forward neighbours, when
        # they exist.
        diffs = []
        if i >= step:
            diffs.append([data[i][a] - data[i - step][a],
                          data[i][b] - data[i - step][b]])
        if i < last_forward:
            diffs.append([data[i + step][a] - data[i][a],
                          data[i + step][b] - data[i][b]])

        if sigma != None:
            column = int(abs(sigma))
            for pair in diffs:
                if sigma < 0:
                    pair[1] -= 0.5 * data[i][column]
                else:
                    pair[1] += 0.5 * data[i][column]

        total_dE = 0.0
        total_dI = 0.0
        for pair in diffs:
            total_dE += pair[0]
            total_dI += pair[1]

        slope = 0.0
        if len(diffs) > 0:
            total_dI /= float(len(diffs))
            total_dE /= float(len(diffs))
            if total_dE != 0:
                slope = total_dI / total_dE

        row[b] = slope
        output.append(row)

    return output
162
def MaxNormalise(data, u=2):
    """Return a deep copy of *data* with column *u* divided by its maximum.

    The input is not modified.  Empty data, or data whose maximum in
    column *u* is zero, is returned as an unscaled copy.
    """
    scaled = copy.deepcopy(data)
    if len(data) <= 0:
        return scaled

    largest = max(row[u] for row in data)
    if largest != 0:
        for row in scaled:
            row[u] = row[u] / largest

    return scaled
176         
def Average(data_sets, u=1):
    """Average several data sets row by row, matching rows on column *u*.

    Parameters:
        data_sets -- iterable of data sets, each a list of numeric rows.
        u         -- index of the column whose value identifies a row.

    Rows with the same value in column *u* (across or within data sets)
    are averaged column by column.  Returns a new list of rows sorted by
    that value.  The input rows are not modified: the running sums are
    kept in copies rather than in the caller's first row for each key.
    """
    # Maps key value -> [column-wise running sum, number of rows summed].
    totals = {}
    for t in data_sets:
        for p in t:
            if p[u] in totals:
                entry = totals[p[u]]
                entry[1] += 1
                for i in range(len(p)):
                    entry[0][i] += p[i]
            else:
                totals[p[u]] = [list(p), 1]

    result = []
    for k in sorted(totals):
        total, count = totals[k]
        result.append([v / float(count) for v in total])
    return result
194
def FullWidthAtHalfMax(data, u=1, x=0):
    """Return the full width at half maximum of the tallest peak.

    Parameters:
        data -- list of rows.
        u    -- index of the column holding the peak height.
        x    -- index of the column holding the position (defaults to 0,
                the column used for the fallback width).

    Starting from the row with the largest value in column *u*, the search
    walks outwards on both sides until the value drops below half of that
    maximum; the result is the distance between those two positions.  If
    either side never drops below half maximum the full span of the data
    (last position minus first) is returned.  Empty data gives 0.0.
    """
    if len(data) == 0:
        return 0.0

    peak = max(range(len(data)), key=lambda i: data[i][u])
    half = 0.50 * data[peak][u]

    lhs = None
    rhs = None
    for i in range(1, len(data)):
        if lhs is None and peak - i >= 0 and data[peak - i][u] < half:
            lhs = data[peak - i][x]
        if rhs is None and peak + i < len(data) and data[peak + i][u] < half:
            rhs = data[peak + i][x]
        if lhs is not None and rhs is not None:
            break

    if lhs is None or rhs is None:
        return abs(data[len(data) - 1][x] - data[0][x])
    return abs(rhs - lhs)
214
def SaveData(filename, data):
    """Write *data* to *filename* as tab-separated text, one row per line.

    Each value is written with ``str()``.  Any existing file is
    overwritten.  The file is closed (and therefore flushed) before
    returning; the previous unbuffered ``open(..., "w", 0)`` never closed
    the handle and is rejected for text files by Python 3.
    """
    with open(filename, "w") as out:
        for a in data:
            out.write("\t".join([str(v) for v in a]) + "\n")
223
def AverageAllData(directory, save=None, normalise=True):
    """Average every data file in *directory* and save the result.

    Parameters:
        directory -- directory scanned with FindDataFiles().
        save      -- output path; defaults to ``directory + "/average.dat"``.
        normalise -- when true, each data set is passed through
                     MaxNormalise() before averaging.

    Returns the averaged data, which is also written with SaveData().
    """
    if save is None:
        save = directory + "/average.dat"

    # The default output lives inside the scanned directory, so a result
    # left over from an earlier run must not be averaged in as if it were
    # a measurement.
    target = os.path.abspath(save)

    data_sets = []
    for f in FindDataFiles(directory):
        if os.path.abspath(f) == target:
            continue
        d = GetData(f)
        if normalise:
            d = MaxNormalise(d)
        data_sets.append(d)

    a = Average(data_sets)
    SaveData(save, a)
    return a
236
def CalibrateData(original, ammeter_scale=1e-6):
    """Return a calibrated deep copy of *original*; the input is untouched.

    Column 1 is scaled by 16.8/4000 and columns 2 and 3 by
    ``ammeter_scale * 0.170 / 268``.  Presumably these convert raw DAC/ADC
    counts into volts and amperes (ShowTCS labels calibrated axes "V" and
    "uA / V") -- the constants themselves are not explained in this file.
    """
    calibrated = copy.deepcopy(original)
    for row in calibrated:
        row[1] = 16.8 * float(row[1]) / 4000.0
        for column in (2, 3):
            row[column] = ammeter_scale * 0.170 * float(row[column]) / 268.0
    return calibrated
244
def ShowTCS(filename, raw=True,calibrate=True, normalise=False, show_error=False, plot=gnuplot.plot,with_="lp", step=1, output=None, title="", master_title="", smooth=0, show_peak=False, inflection=1):
    """Plot a total current spectrum S(E) = dI/dE with gnuplot.

    Parameters:
        filename     -- data file/directory path, or an already loaded data
                        list.
        raw          -- True when the data is I(E) and must be
                        differentiated; False when it is already S(E), in
                        which case calibration and normalisation are
                        skipped.
        calibrate    -- apply CalibrateData() first.
        normalise    -- apply MaxNormalise() first.
        show_error   -- also plot the derivative shifted by -/+ half of
                        column 3.
        plot         -- callable receiving the plot items, or None to
                        return the items instead of plotting.
        with_        -- gnuplot "with" style for the curves.
        step         -- stride passed to Derivative().
        output       -- optional PNG file name.
        title        -- curve title; defaults to the base name of the file.
        master_title -- plot title; a default (plus the "Sample" parameter
                        read from the file, when present) is used if empty.
        smooth       -- window half-width for Smooth(), or a list
                        ``[passes, half_width]``.
        show_peak    -- mark peaks found by SmoothPeakFind().
        inflection   -- passed to SmoothPeakFind().

    Returns the list of plot items when *plot* is None, otherwise the
    (smoothed / calibrated / normalised) data that was differentiated.
    Empty data is returned immediately without plotting.
    """
    if raw == False:
        calibrate = False
        normalise = False

    if isinstance(filename, str):
        data = GetData(filename)
    else:
        data = filename
        filename = "tcs data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # A list means [number of passes, half width]; test the type first so
    # the list is never compared against a number.
    if isinstance(smooth, list):
        for _ in range(smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA / V"]
    else:
        units = ["DAC counts", "ADC counts / DAC counts"]

    if not normalise:
        gnuplot("set ylabel \"dI(E)/dE ("+str(units[1])+")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"dI(E)/dE (normalised)\"")

    to_file = output != None and isinstance(output, str)
    if to_file:
        gnuplot("set term png size 640,480")
        gnuplot("set output \""+str(output)+"\"")

    if master_title == "":
        master_title = "Total Current Spectrum S(E)"
        if plot == gnuplot.plot and filename != "tcs data":
            p = ReadParameters(filename)
            if "Sample" in p:
                master_title += "\\nSample: "+p["Sample"]

    gnuplot("set title \""+str(master_title)+"\"")
    gnuplot("set xlabel \"U ("+str(units[0])+")\"")

    if raw:
        d = Derivative(data, 1, 2, step=step)
    else:
        d = data

    ymax = 0.01 + 1.2 * max(d, key=lambda e : e[2])[2]
    ymin = -0.01 + 1.2 * min(d, key=lambda e : e[2])[2]
    gnuplot("set yrange ["+str(ymin)+":"+str(ymax)+"]")

    plotList = []
    plotList.append(Gnuplot.Data(d, using="2:3", with_=with_,title=title))

    if show_error:
        error1 = Derivative(data, 1, 2, -3,step=step)
        error2 = Derivative(data, 1, 2, +3,step=step)
        # The error curves use the same style as the main curve (the
        # previous code referenced an undefined name `w` here).
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_,title="-sigma/2"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="+sigma/2"))

    if show_peak:
        peak = SmoothPeakFind(d, ap=DoNothing, stop=1, inflection=inflection)
        plotList += PlotPeaks(peak,with_="l lt -1", plot=None)

    if plot != None:
        plot(*plotList)
        time.sleep(0.2)

    if to_file:
        gnuplot("set term wxt")

    if plot == None:
        return plotList
    return data
333
def ShowData(filename,calibrate=True, normalise=False, show_error=False, plot=gnuplot.plot,with_="lp", step=1, output=None, title="", master_title="Sample Current I(E)", smooth=0):
    """Plot the sample current I(E) (column 3 against column 2) with gnuplot.

    Parameters:
        filename     -- data file/directory path, or an already loaded data
                        list.
        calibrate    -- apply CalibrateData() first.
        normalise    -- apply MaxNormalise() first.
        show_error   -- also plot the curve shifted by -/+ half of column 3.
        plot         -- callable receiving the plot items, or None to
                        return the items instead of plotting.
        with_        -- gnuplot "with" style for the curves.
        step         -- unused; accepted so the signature matches ShowTCS().
        output       -- optional PNG file name.
        title        -- curve title; defaults to the base name of the file.
        master_title -- plot title.
        smooth       -- window half-width for Smooth(), or a list
                        ``[passes, half_width]``.

    Returns the plotted data when *plot* is given, otherwise the list of
    plot items.  Empty data is returned immediately without plotting.
    """
    if isinstance(filename, str):
        data = GetData(filename)
    else:
        data = filename
        filename = "raw data"

    if title == "":
        title = BaseName(filename)

    if len(data) <= 0:
        return data

    # Dispatch on the type of `smooth` (not of `data`, which is always a
    # list): a list means [number of passes, half width].
    if isinstance(smooth, list):
        for _ in range(0, smooth[0]):
            data = Smooth(data, m=smooth[1])
    elif smooth > 0:
        data = Smooth(data, m=smooth)

    if calibrate:
        data = CalibrateData(data)
        units = ["V", "uA"]
    else:
        units = ["DAC counts", "ADC counts"]

    if not normalise:
        gnuplot("set ylabel \"I(E) ("+str(units[1])+")\"")
    else:
        data = MaxNormalise(data)
        gnuplot("set ylabel \"I(E) (normalised)\"")

    to_file = output != None and isinstance(output, str)
    if to_file:
        gnuplot("set term png size 640,480")
        gnuplot("set output \""+str(output)+"\"")

    gnuplot("set title \""+str(master_title)+"\"")
    gnuplot("set xlabel \"U ("+str(units[0])+")\"")

    # The range is taken from the data actually plotted (the previous code
    # read an undefined name `d` and always failed here).
    ymax = 0.005 + 1.2 * max(data, key=lambda e : e[2])[2]
    ymin = -0.005 + 1.2 * min(data, key=lambda e : e[2])[2]
    gnuplot("set yrange ["+str(ymin)+":"+str(ymax)+"]")

    plotList = []
    plotList.append(Gnuplot.Data(data, using="2:3", with_=with_,title=title))
    time.sleep(0.1)

    if show_error:
        error1 = copy.deepcopy(data)
        error2 = copy.deepcopy(data)
        for i in range(len(data)):
            error1[i][2] -= 0.50*float(data[i][3])
            error2[i][2] += 0.50*float(data[i][3])
        # Same style as the main curve (previously an undefined name `w`).
        plotList.append(Gnuplot.Data(error1, using="2:3", with_=with_,title="Error : Low bound"))
        plotList.append(Gnuplot.Data(error2, using="2:3", with_=with_, title="Error : Upper bound"))

    if plot != None:
        plot(*plotList)
        if to_file:
            gnuplot("set term wxt")
        return data
    else:
        return plotList
404
def ReadParameters(filename):
    """Read ``name = value`` pairs from *filename* into an ordered dict.

    Every line containing "=" contributes one entry: the text before the
    first "=" is the name and the text between the first and second "=" is
    the value, both stripped of "#", spaces and line endings.  A later
    occurrence of a name overwrites the earlier value.
    """
    parameters = odict.odict()
    handle = open(filename, "r")
    for line in handle:
        fields = line.split("=")
        if len(fields) < 2:
            continue
        name = fields[0].strip("# \r\n")
        parameters[name] = fields[1].strip("# \r\n")
    handle.close()
    return parameters
421
def PlotParameters(filename):
        """Read the parameters of *filename* via ReadParameters().

        NOTE(review): the parsed parameters are discarded and nothing is
        plotted or returned -- this looks like an unfinished stub.
        """
        ReadParameters(filename)
424
def Smooth(data, m, k=2):
    """Return a deep copy of *data* with column *k* replaced by a moving average.

    For row i the average runs over rows ``i-m`` to ``i+m-1`` (clipped to
    the data), i.e. the window holds up to 2*m samples and does not include
    row ``i+m``.  When the window is empty (for example m == 0) the
    original value is kept.  The input is not modified.
    """
    smoothed = copy.deepcopy(data)
    n = len(smoothed)
    for i in range(n):
        window = range(max(i - m, 0), min(i + m, n))
        if len(window) > 0:
            total = 0.0
            for j in window:
                total += data[j][k]
            smoothed[i][k] = total / float(len(window))
        else:
            smoothed[i][k] = data[i][k]
    return smoothed
440
def PeakFind(data, k=2,threshold=0.00, inflection=0):
    """Find the local extrema of column *k*.

    An interior row is reported when both neighbours lie strictly on the
    same side of it (a maximum or a minimum) and each neighbour differs
    from it by at least ``threshold * abs(value)``.  Each reported row is a
    copy of the data row with *inflection* appended.

    When *inflection* is positive the search is repeated on
    ``Derivative(data)`` with ``inflection - 1`` and those results are
    appended as well.
    """
    found = []
    for i in range(1, len(data) - 1):
        here = data[i][k]
        left = data[i - 1][k] - here
        right = data[i + 1][k] - here
        cutoff = threshold * abs(here)
        if abs(left) < cutoff or abs(right) < cutoff:
            continue
        if left * right > 0:
            found.append(data[i] + [inflection])

    if inflection > 0:
        found += PeakFind(Derivative(data), k=k, threshold=threshold, inflection=inflection-1)

    return found
462
def SmoothPeakFind(data, a=1, k=2, ap=DoNothing, stop=10,smooth=5, inflection=0):
    """Track peaks of *data* through successive smoothing passes.

    Parameters:
        data       -- list of rows.
        a          -- column compared to decide whether two peaks match.
        k          -- column searched for peaks and smoothed.
        ap         -- transform applied to the data before each search.
        stop       -- number of passes.
        smooth     -- window half-width given to Smooth() between passes.
        inflection -- passed to PeakFind().

    Returns a list of tracks.  Each track is a list of entries
    ``[pass_number] + peak_row``.  On the first pass every peak starts a
    track.  On later passes a peak joins the track whose latest entry is
    at most one pass old and closest in column *a*, provided the distance
    does not exceed 100; otherwise it starts a new track.
    """
    tracks = []
    current = data

    level = 0
    while level < stop:
        for p in PeakFind(ap(current), k=k, inflection=inflection):
            entry = [level] + [f for f in p]

            if level == 0:
                tracks.append([entry])
                continue

            # [track index, distance] for each track still being followed.
            candidates = []
            for index, track in enumerate(tracks):
                tail = track[-1]
                if level - tail[0] > 1:
                    continue
                candidates.append([index, abs(p[a] - tail[1 + a])])

            best = None
            if len(candidates) > 0:
                best = min(candidates, key=lambda e: e[1])

            if best is None or best[1] > 100:
                tracks.append([entry])
            else:
                tracks[best[0]].append(entry)

        level += 1
        current = Smooth(current, m=smooth, k=k)

    return tracks
508         
509                         
510
def PlotPeaks(peaks, calibrate=True, with_="lp", plot=gnuplot.replot):
    """Build (and optionally draw) gnuplot items for peak tracks.

    Parameters:
        peaks     -- list of tracks as returned by SmoothPeakFind().
        calibrate -- unused.
        with_     -- gnuplot "with" style.  Tracks whose final value (the
                     inflection marker) is below 1 are drawn with the line
                     type replaced by "lt 9".
        plot      -- callable receiving the plot items, or None to only
                     build them.

    Each track is extended by a copy of its last entry with the pass
    number incremented, so that a single-entry track still forms a visible
    segment.  The extension is made on a copy: the caller's *peaks* are not
    modified.  Empty tracks are skipped.  Returns the list of plot items.
    """
    plotList = []
    for p in peaks:
        if len(p) == 0:
            continue

        tail = copy.deepcopy(p[len(p)-1])
        tail[0] += 1
        track = list(p) + [tail]

        # Choose the style per track so the override for one track does
        # not leak into the tracks plotted after it.
        style = with_
        if tail[len(tail)-1] < 1:
            style = with_.split(" lt")[0] + " lt 9"
        plotList.append(Gnuplot.Data(track, using="3:1", with_=style))

    if len(plotList) > 0 and plot != None:
        plot(*plotList)
        time.sleep(0.2)

    return plotList
537                 
def main():
    """Command-line entry point: plot each file named on the command line.

    Arguments are processed left to right.  Switches change the settings
    used for the files that follow them:

        --raw                  plot I(E) with ShowData
        --tcs                  plot S(E) with ShowTCS (default)
        --output FILE          write postscript to FILE
        --wxt                  use the wxt terminal
        --normalise            normalise the data
        --unnormalise          do not normalise the data
        --title TEXT           curve title
        --master_title TEXT    plot title
        --smooth N | PxN       smoothing half-width N, optionally P passes
        --with STYLE           gnuplot "with" style

    Any other argument is treated as a file name and plotted.  Afterwards
    the user is prompted for an optional file name to save the plot to.

    Returns 1 when no arguments were given, otherwise None.  Exits with
    status 1 when a switch is missing its argument.
    """
    if len(sys.argv) < 2:
        sys.stderr.write(sys.argv[0] + " - Require arguments (filename)\n")
        return 1

    def argument(index):
        # Value following the switch at sys.argv[index]; exits if absent.
        if index + 1 >= len(sys.argv):
            sys.stderr.write("Need argument for "+sys.argv[index]+" switch\n")
            sys.exit(1)
        return sys.argv[index + 1]

    i = 1
    plotFunc = ShowTCS
    normalise = False
    title = ""
    master_title = ""
    smooth = 0
    with_ = "lp"
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg == "--raw":
            plotFunc = ShowData
        elif arg == "--tcs":
            plotFunc = ShowTCS
        elif arg == "--output":
            target = argument(i)
            gnuplot("set term postscript colour")
            gnuplot("set output \""+target+"\"")
            i += 1
        elif arg == "--wxt":
            gnuplot("set term wxt")
        elif arg == "--normalise":
            normalise = True
        elif arg == "--unnormalise":
            normalise = False
        elif arg == "--title":
            title = argument(i)
            i += 1
        elif arg == "--master_title":
            master_title = argument(i)
            i += 1
        elif arg == "--smooth":
            # "N" gives a half-width; "PxN" gives [passes, half-width].
            smooth = [int(v) for v in argument(i).split("x")]
            if len(smooth) <= 1:
                smooth = smooth[0]
            i += 1
        elif arg == "--with":
            with_ = argument(i)
            i += 1
        else:
            plotFunc(arg, plot=gnuplot.replot, normalise=normalise, title=title, master_title=master_title, smooth=smooth, with_=with_)

        i += 1

    sys.stdout.write("Done. Press enter to exit, or type name of file to save as.\n")
    out = sys.stdin.readline().strip("\t\r\n #")
    if out != "":
        gnuplot("set term postscript colour")
        gnuplot("set output \""+out+"\"")
        gnuplot.replot()
613
614
def ModelTCS(f, sigma, Emin, Emax, dE):
    """Tabulate a model spectrum built from the callables *f* and *sigma*.

    For each energy E from Emin (inclusive) up to Emax (exclusive) in steps
    of dE the value is

        (1 - sigma(0)) * f(-E)
          + FuncIntegrate(e -> f(e - E) * FuncDerivative(sigma, E, dE))

    with the integral taken from Emin to Emax in steps of dE.  Returns rows
    of the form ``[0.0, E, S, 0.0]``.
    """
    rows = []
    energy = Emin
    while energy < Emax:
        direct = (1 - sigma(0)) * f(-energy)
        integral = FuncIntegrate(
            lambda e: f(e - energy) * FuncDerivative(sigma, energy, dE),
            Emin, Emax, dE)
        rows.append([0.00, energy, direct + integral, 0.00])
        energy += dE
    return rows
623
def IntegrateTCS(data, imin, imax=0, di=1):
    """Sum column 2 times the spacing of column 1 over a range of rows.

    Rows imin, imin+di, ... below *imax* each contribute
    ``data[i][2] * (data[i+1][1] - data[i][1])``.  An *imax* of 0 means
    "up to the last row".
    """
    upper = len(data) - 1 if imax == 0 else imax
    area = 0.0
    index = imin
    while index < upper:
        width = data[index + 1][1] - data[index][1]
        area += data[index][2] * width
        index += di
    return area
634
def FuncIntegrate(f, xmin, xmax, dx):
    """Approximate the integral of *f* by a rectangle sum.

    *f* is sampled at xmin, xmin+dx, ... for as long as the sample point
    does not exceed xmax, and each sample contributes ``f(x) * dx``.
    """
    area = 0.0
    position = xmin
    while position <= xmax:
        strip = f(position) * dx
        area += strip
        position += dx
    return area
642
def FuncDerivative(f, x, dx):
    """Return the central-difference slope of *f* at *x* with half-step *dx*."""
    rise = f(x + dx) - f(x - dx)
    return 0.50 * rise / dx
645
def FitTCS(data, min_mse=1e-4, max_fail=100, max_adjust=4,divide=10, plot=gnuplot.plot,smooth=0):
    """Fit a sum of gaussians to a total current spectrum by random search.

    Parameters:
        data       -- file name (loaded, calibrated, normalised and
                      differentiated here) or an already prepared S(E)
                      data list.
        min_mse    -- stop once the mean square error drops to this value.
        max_fail   -- stop after this many consecutive rejected changes.
        max_adjust -- the perturbation size is halved (down to
                      ``1 / 2**max_adjust``) once more than half of
                      *max_fail* consecutive changes were rejected.
        divide     -- only every divide-th data point is compared with
                      the model.
        plot       -- when not None the data and the model are plotted.
        smooth     -- window half-width for Smooth(), or a list
                      ``[passes, half_width]``.

    One gaussian ``[amplitude, position, width]`` is seeded per peak found
    by SmoothPeakFind().  Each iteration perturbs one randomly chosen
    parameter and keeps the change only if it lowers the mean square
    error.

    Returns ``[fits, model]`` when plotting, ``[fits, model, plotItems]``
    otherwise.  *fits* is sorted by decreasing peak height.
    """
    if isinstance(data, str):
        d = GetData(data)
        d = CalibrateData(d)
        d = MaxNormalise(d)
        d = Derivative(d)
    else:
        d = data

    if smooth != 0:
        if isinstance(smooth, list):
            for _ in range(smooth[0]):
                d = Smooth(d, m=smooth[1])
        else:
            d = Smooth(d, m=smooth)

    # NOTE(review): `smooth` is passed on, so the plotted curve is smoothed
    # a second time while the fit uses the once-smoothed data -- confirm
    # that this is intended.
    plotItems = ShowTCS(d, raw=False,smooth=smooth,plot=None)
    plotItems.append(None)  # placeholder, replaced by the model curve below

    peaks = SmoothPeakFind(d, smooth=5, stop=1, inflection=0)
    peaks.sort(key = lambda e : e[len(e)-1][1])

    # Seed one gaussian [amplitude, position, width] per peak.
    fits = []
    for i in range(0, len(peaks)):
        p = peaks[i]
        l = p[len(p)-1]
        if l[len(l)-1] == 0:
            fits.append([l[3], l[2], 1.0])
        else:
            # Fall back on the peaks two places either side.  The last
            # entry of a track is at index len-1 (indexing with len, as
            # before, always raised IndexError).
            if i-2 >= 0:
                l = peaks[i-2][len(peaks[i-2])-1]
                if l[len(l)-1] == 0:
                    fits.append([l[3], l[2], 1.0])
            if i+2 <= len(peaks)-1:
                l = peaks[i+2][len(peaks[i+2])-1]
                if l[len(l)-1] == 0:
                    fits.append([l[3], l[2], 1.0])

    # Initial width: half the distance to the nearest neighbouring fit.
    for i in range(len(fits)):
        left = 2.0
        right = 2.0
        if i > 0:
            left = fits[i-1][1] - fits[i][1]
        if i < len(fits)-1:
            right = fits[i+1][1] - fits[i][1]

        fits[i][2] = min([abs(0.5*left), abs(0.5*right)])

    def tcs(E):
        # Model value at E: sum of all gaussians.
        total = 0.0
        for f in fits:
            total += f[0] * gaussian(E - f[1], f[2])
        return total

    def build_model():
        # Tabulate the model at the spacing of the sub-sampled data.
        return table(lambda e : [0.00, e, tcs(e), 0.00], 0, 16.8, divide*d[len(d)-1][1]/(len(d)))

    mse = 1
    failcount = 0
    adjust = 1.0
    model = None

    # With no fits there is nothing to perturb (randint(0, -1) would raise).
    while len(fits) > 0 and failcount < max_fail and mse > min_mse:
        i = random.randint(0, len(fits)-1)
        j = random.randint(0, len(fits[i])-1)

        old = fits[i][j]
        old_mse = mse

        fits[i][j] += adjust * (random.random() - 0.50)
        # Parameter index 2 is the gaussian width, which must stay
        # positive.  (The guard previously tested the fit index `i`.)
        if j == 2:
            while fits[i][j] <= 0.0005:
                fits[i][j] = adjust * (random.random() - 0.50)

        model = build_model()
        mse = MeanSquareError(model, d[0::divide])
        if mse >= old_mse:
            # Reject the change and, after many rejections, search with
            # smaller perturbations.
            fits[i][j] = old
            failcount += 1
            if failcount > max_fail / 2:
                if adjust > 1.0/(2.0**max_adjust):
                    adjust /= 2
            mse = old_mse
        else:
            failcount = 0

    if model is None:
        # The loop never ran; still return a model for the current fits.
        model = build_model()

    plotItems[len(plotItems)-1] = Gnuplot.Data(model, using="2:3", with_="l lt 3", title="model")
    time.sleep(0.1)

    fits.sort(key = lambda e : e[0] * gaussian(0, e[2]), reverse=True)

    if plot != None:
        gnuplot("set title \"MSE = "+str(mse)+"\\nfailcount = "+str(failcount)+"\\nadjust = "+str(adjust)+"\"")
        # NOTE(review): always draws with gnuplot.plot, regardless of the
        # callable passed as `plot`.
        gnuplot.plot(*plotItems)

        return [fits, model]
    else:
        return [fits, model,plotItems]
770
771
772         #return model
773
def SaveFit(filename, fit):
    """Write a fit to *filename* in the format read by LoadFit().

    The file starts with the header line "# TCS Fit" followed by one line
    per gaussian holding its first three values separated by tabs.  The
    file is closed even if writing fails; unbuffered text mode
    (``open(..., "w", 0)``) is not used because Python 3 rejects it.
    """
    with open(filename, "w") as out:
        out.write("# TCS Fit\n")
        for f in fit:
            out.write(str(f[0]) + "\t" + str(f[1]) + "\t" + str(f[2]) + "\n")
782
783
def LoadFit(filename):
    """Load a fit written by SaveFit().

    The first line must be exactly "# TCS Fit"; otherwise an error is
    written to stderr and the process exits with status 1 (a failure must
    not report success with status 0).  Following lines are read until one
    does not hold exactly three tab-separated fields.

    Returns the list of ``[a, b, c]`` float triples sorted by decreasing
    ``a * gaussian(0, c)``.
    """
    fit = []
    with open(filename, "r") as infile:
        if infile.readline() != "# TCS Fit\n":
            sys.stderr.write("Error loading fit from file " + str(filename) + "\n")
            sys.exit(1)

        while True:
            f = infile.readline().strip("# \r\n\t").split("\t")
            if len(f) != 3:
                break
            fit.append([float(v) for v in f])

    fit.sort(key = lambda e : e[0] * gaussian(0, e[2]), reverse=True)
    return fit
807
def MeanSquareError(model, real, k = 2):
    """Return the mean squared difference between two data sets in column *k*.

    Rows are compared pairwise by index over the part the two lists have
    in common, and the sum is divided by the number of rows actually
    compared.  (Dividing by ``len(model)`` while summing over ``len(real)``
    rows gave a wrong mean whenever the lengths differed.)  Returns 0.0
    when there is nothing to compare.
    """
    count = min(len(model), len(real))
    if count == 0:
        return 0.0

    mse = 0.0
    for i in range(count):
        mse += (model[i][k] - real[i][k])**2

    return mse / count
816
def delta(x):
    """Return 1.0 when *x* equals zero and 0.0 otherwise."""
    return 1.0 if x == 0 else 0.0
822
def table(f, xmin, xmax, dx):
    """Return ``[f(x) for x = xmin, xmin+dx, ...]`` while x does not exceed xmax."""
    values = []
    position = xmin
    while position <= xmax:
        values.append(f(position))
        position += dx
    return values
830
def gaussian(x, sigma):
    """Return the normal density with standard deviation *sigma* at offset *x*.

    A *sigma* of exactly zero yields 0.0 instead of dividing by zero.
    """
    if sigma != 0.0:
        exponent = -(x**2.0) / (2.0 * sigma**2.0)
        normalisation = sigma * (2.0 * math.pi)**0.50
        return math.exp(exponent) / normalisation
    return 0.0
835
def step(x, sigma, T):
    """Return ``1 / (exp((x - sigma) / T) + 1)``, a smooth step centred on *sigma*.

    *T* sets the width of the step; a *T* of zero yields 1.0 for every *x*.
    """
    if T != 0:
        return 1.0 / (math.exp((x - sigma) / T) + 1.0)
    return 1.0
840
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
        sys.exit(main())

UCC git Repository :: git.ucc.asn.au