Surface NMR processing and inversion GUI

filtfilt.py 2.2KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from numpy import vstack, hstack, eye, ones, zeros, linalg, \
  2. newaxis, r_, flipud, convolve, matrix, array
  3. from scipy.signal import lfilter
  def lfilter_zi(b, a):
      """Return the ``lfilter`` state that puts a unit step at steady state.

      The returned vector ``zi`` is the internal filter state for which
      ``lfilter(b, a, ones(N), zi=zi)`` starts without a start-up transient.
      For other signals, scale it by the first input sample, as ``filtfilt``
      does. It is computed as in [Gust96] by solving ``(I - A) zi = B``,
      where ``A`` is the transposed companion matrix of the (normalised)
      denominator and ``B = b[1:] - a[1:] * b[0]``.

      [Gust96] Fredrik Gustafsson, "Determining the initial states in
      forward-backward filtering", IEEE Transactions on Signal Processing,
      vol. 44, no. 4, pp. 988--992, April 1996.

      Parameters
      ----------
      b, a : sequence of numbers
          Numerator and denominator coefficients. They may have different
          lengths; the shorter one is zero-padded. Both are normalised by
          ``a[0]``, which is what ``lfilter`` does internally.

      Returns
      -------
      ndarray
          State vector of length ``max(len(a), len(b)) - 1``. It is empty
          for a pure-gain (single-tap) filter.

      Raises
      ------
      ValueError
          If ``a`` is empty or ``a[0]`` is zero.
      """
      n = max(len(a), len(b))
      # Zero-pad both coefficient vectors to the common length n. r_ also
      # turns plain Python sequences into arrays.
      a = r_[a, zeros(n - len(a))]
      b = r_[b, zeros(n - len(b))]
      if n == 0 or a[0] == 0:
          raise ValueError("a[0] must be non-zero.")
      # The Gustafsson formulation assumes a monic denominator (a[0] == 1).
      b = b / a[0]
      a = a / a[0]
      if n == 1:
          # A pure gain has no delay elements and therefore no state.
          return zeros(0)
      # Build I - A, where A has first column -a[1:] and ones on the
      # superdiagonal (the transposed companion matrix of a).
      i_minus_a = eye(n - 1) - hstack(
          (-a[1:, newaxis], vstack((eye(n - 2), zeros(n - 2)))))
      rhs = b[1:] - a[1:] * b[0]
      # Solving the system is more accurate than forming the inverse.
      return linalg.solve(i_minus_a, rhs)
  def filtfilt(b, a, x):
      """Filter a 1-D signal forward and backward for zero phase distortion.

      The signal is first extended at both ends. Each extension is an odd
      reflection of ``3 * max(len(a), len(b))`` samples about the end point,
      and the end point itself is not repeated. The extended signal is
      filtered forward with ``lfilter``, with the initial state taken from
      ``lfilter_zi`` and scaled by the first padded sample. The result is
      filtered again in reverse time and trimmed back to the original length.
      This is the same padding scheme as ``scipy.signal.filtfilt`` with
      ``padtype='odd'`` and its default ``padlen``.

      Parameters
      ----------
      b, a : sequence of numbers
          Numerator and denominator coefficients. The shorter one is
          zero-padded to the longer length.
      x : array_like
          One-dimensional input signal.

      Returns
      -------
      ndarray
          The filtered signal, with the same length as ``x``.

      Raises
      ------
      ValueError
          If ``x`` is not 1-D, or if ``len(x) <= 3 * max(len(a), len(b))``.
      """
      from numpy import asarray

      x = asarray(x)
      ntaps = max(len(a), len(b))
      edge = ntaps * 3
      if x.ndim != 1:
          raise ValueError("filtfilt only accepts 1-dimensional arrays.")
      # The odd extension reads x[1]..x[edge] and x[-2]..x[-edge-1], so x
      # must be strictly longer than edge.
      if x.size <= edge:
          raise ValueError(
              "Input vector needs to be bigger than 3 * max(len(a), len(b)).")
      # Equalise the coefficient lengths so that the state vector has
      # ntaps - 1 entries.
      a = r_[a, zeros(ntaps - len(a))]
      b = r_[b, zeros(ntaps - len(b))]
      zi = lfilter_zi(b, a)
      # Odd reflection about each end point, `edge` samples per side. The
      # end points themselves are not duplicated.
      s = r_[2 * x[0] - x[edge:0:-1], x, 2 * x[-1] - x[-2:-edge - 2:-1]]
      # Forward pass, started at the steady state for the first padded sample.
      y, _ = lfilter(b, a, s, -1, zi * s[0])
      # Backward pass on the time-reversed output. Its first sample is y[-1].
      y, _ = lfilter(b, a, flipud(y), -1, zi * y[-1])
      # Restore time order and drop the padding.
      return flipud(y)[edge:-edge]
  if __name__ == '__main__':
      # Demo: compare a single-pass lfilter with the zero-phase filtfilt on a
      # noisy sine wave.
      # The NumPy aliases (sin, arange, pi, randn) are no longer exported
      # from the top-level scipy namespace, so import them from NumPy.
      from numpy import arange, pi, sin
      from numpy.random import randn
      from scipy.signal import butter
      # pylab.hold() was removed in Matplotlib 3.0. Successive plot() calls
      # now draw on the same axes by default.
      from pylab import legend, plot, show

      t = arange(-1, 1, .01)
      x = sin(2 * pi * t * .5 + 2)
      xn = x + randn(len(t)) * 0.05
      b, a = butter(3, 0.05)
      z = lfilter(b, a, xn)
      y = filtfilt(b, a, xn)
      plot(x, 'c')
      plot(xn, 'k')
      plot(z, 'r')
      plot(y, 'g')
      legend(('original', 'noisy signal', 'lfilter - butter 3 order',
              'filtfilt - butter 3 order'))
      show()