Package qudi_hira_analysis
Analytics suite for qubit SPM using FPGA timetaggers
Getting started
Start by creating an instance of the DataHandler
class. Specify the location you
want to load data from (data_folder
), the location you want to save figures to (
figure_folder
) and (optionally) the name of the measurement folder (
measurement_folder
). If a measurement folder is specified, it is appended to both the data folder path
and the figure folder path, giving the full paths used for loading the measurement
data and saving its figures.
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from qudi_hira_analysis import DataHandler
dh = DataHandler(
data_folder=Path("C:/Data"), # Path to the data folder
figure_folder=Path("C:/QudiHiraAnalysis"), # Path to the figure folder
measurement_folder=Path("20230101_NV1") # Name of the measurement folder
)
# Output:
# qudi_hira_analysis.data_handler :: INFO :: Data folder path is C:/Data/20230101_NV1
# qudi_hira_analysis.data_handler :: INFO :: Figure folder path is C:/QudiHiraAnalysis/20230101_NV1
Loading data
To load a specific set of measurements from the data folder, use the
DataHandler.load_measurements()
method. The method takes a string as an argument
and searches for files with the string in the path. The files are lazy-loaded,
so the data is only loaded when it is needed. The method returns a dictionary,
where the keys are the timestamps of the measurements and the values are
MeasurementDataclass
objects.
# Search and lazy-load all pulsed measurements with "odmr" in the path into Dataclasses
odmr_measurements = dh.load_measurements("odmr", pulsed=True)
odmr = odmr_measurements["20230101-0420-00"]
>>> odmr
MeasurementDataclass(timestamp='2023-01-01 04:20:00', filename='odmr.dat')
>>> odmr.data
Controlled variable(Hz) Signal
0 2850.000000 1.001626
1 2850.816327 0.992160
2 2851.632653 0.975900
3 2852.448980 0.990770
4 2853.265306 0.994068
... ... ...
Fitting data
To fit data, call the AnalysisLogic.fit()
method. This method accepts pandas
DataFrames, numpy arrays or pandas Series as inputs. To get the full list of
available fit routines, explore the DataHandler.fit_function
attribute or call
AnalysisLogic.get_all_fits()
.
The fit functions available are:
| Dimension | Fit                           |
|-----------|-------------------------------|
| 1D        | decayexponential              |
|           | biexponential                 |
|           | decayexponentialstretched     |
|           | gaussian                      |
|           | gaussiandouble                |
|           | gaussianlinearoffset          |
|           | hyperbolicsaturation          |
|           | linear                        |
|           | lorentzian                    |
|           | lorentziandouble              |
|           | lorentziantriple              |
|           | sine                          |
|           | sinedouble                    |
|           | sinedoublewithexpdecay        |
|           | sinedoublewithtwoexpdecay     |
|           | sineexponentialdecay          |
|           | sinestretchedexponentialdecay |
|           | sinetriple                    |
|           | sinetriplewithexpdecay        |
|           | sinetriplewiththreeexpdecay   |
| 2D        | twoDgaussian                  |
x_fit, y_fit, result = dh.fit(x="Controlled variable(Hz)", y="Signal",
fit_function=dh.fit_function.lorentziandouble,
data=odmr.data)
# Plot the data and the fit
ax = sns.scatterplot(x="Controlled variable(Hz)", y="Signal", data=odmr.data, label="Data")
sns.lineplot(x=x_fit, y=y_fit, ax=ax, label="Fit")
>>> print(result.fit_report())
[[Model]]
(((Model(amplitude_function, prefix='l0_') * Model(physical_lorentzian, prefix='l0_')) + Model(constant_function)) + (Model(amplitude_function, prefix='l1_') * Model(physical_lorentzian, prefix='l1_')))
[[Fit Statistics]]
# fitting method = leastsq
# function evals = 57
# data points = 50
# variables = 7
chi-square = 0.00365104
reduced chi-square = 8.4908e-05
Akaike info crit = -462.238252
Bayesian info crit = -448.854091
R-squared = 0.98267842
[[Variables]]
l0_amplitude: -0.22563916 +/- 0.00770009 (3.41%) (init = -0.1766346)
l0_center: 2866.58588 +/- 0.05208537 (0.00%) (init = 2866.327)
l0_sigma: 1.51496671 +/- 0.08457482 (5.58%) (init = 2.834718)
offset: 1.00213733 +/- 0.00189341 (0.19%) (init = 0.9965666)
l1_amplitude: -0.21380290 +/- 0.00772333 (3.61%) (init = -0.1676134)
l1_center: 2873.49077 +/- 0.05478295 (0.00%) (init = 2872.857)
l1_sigma: 1.49704419 +/- 0.08813686 (5.89%) (init = 2.987285)
l0_fwhm: 3.02993343 +/- 0.16914964 (5.58%) == '2*l0_sigma'
l0_contrast: -22.5157920 +/- 0.76625384 (3.40%) == '(l0_amplitude/offset)*100'
l1_fwhm: 2.99408834 +/- 0.17627372 (5.89%) == '2*l1_sigma'
l1_contrast: -21.3346912 +/- 0.76857477 (3.60%) == '(l1_amplitude/offset)*100'
[[Correlations]] (unreported correlations are < 0.100)
C(l0_amplitude, l0_sigma) = +0.6038
C(l1_amplitude, l1_sigma) = +0.6020
C(l0_sigma, offset) = +0.3301
C(offset, l1_sigma) = +0.3250
C(l0_sigma, l1_sigma) = -0.1588
C(l0_center, l1_sigma) = -0.1531
C(l0_sigma, l1_center) = +0.1531
Saving data
To save figures, call the IOHandler.save_figures()
method. By default,
the figures are saved as JPG, PDF, PNG and SVG. This can be changed by setting the
only_jpg
or only_pdf
arguments to True
. All other keyword arguments are passed
to the matplotlib.pyplot.savefig()
function.
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath=Path("odmr"), fig=ax.get_figure(),
only_pdf=True, bbox_inches="tight")
# The figure is saved to C:/QudiHiraAnalysis/20230101_NV1/odmr.pdf
Examples
NV-ODMR map
Extract a heatmap of ODMR splittings from a 2D raster NV-ODMR map.
# Extract ODMR measurements from the measurement folder
odmr_measurements = dh.load_measurements("2d_odmr_map")
odmr_measurements = dict(sorted(odmr_measurements.items()))
# Perform parallel (=num CPU cores) ODMR fitting
odmr_measurements = dh.fit_raster_odmr(odmr_measurements)
# Calculate 2D ODMR map from the fitted ODMR measurements
pixels = int(np.sqrt(len(odmr_measurements)))
image = np.zeros((pixels, pixels))
for odmr in odmr_measurements.values():
    row, col = odmr.xy_position
    if len(odmr.fit_model.params) > 6:
        # Calculate double Lorentzian splitting
        image[row, col] = np.abs(odmr.fit_model.best_values["l1_center"]
                                 - odmr.fit_model.best_values["l0_center"])
ax = sns.heatmap(image, cbar_kws={"label": "Delta E (MHz)"})
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath="2d_odmr_map", fig=ax.get_figure(), only_jpg=True)
NV-PL map
Extract a heatmap of NV photo-luminescence from a 2D raster NV-PL map.
pixel_scanner_measurements = dh.load_measurements("PixelScanner")
fwd, bwd = pixel_scanner_measurements["20230101-0420-00"].data
# If size is known, it can be specified here
fwd.size["real"] = {"x": 1e-6, "y": 1e-6, "unit": "m"}
fig, ax = plt.subplots()
# Perform (optional) image corrections
fwd.filter_gaussian(sigma=0.5)
# Add scale bar, color bar and plot the data
img = fwd.show(cmap="inferno", ax=ax)
fwd.add_scale(length=1e-6, ax=ax, height=1)
cbar = fig.colorbar(img)
cbar.set_label("NV-PL (kcps)")
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath="nv_pl_scan", fig=fig, only_jpg=True)
Nanonis AFM measurements
Extract a heatmap of AFM data from a 2D raster Nanonis AFM scan.
afm_measurements = dh.load_measurements("Scan", extension=".sxm", qudi=False)
afm = afm_measurements["20230101-0420-00"].data
# Print the channels available in the data
afm.list_channels()
topo = afm.get_channel("Z")
fig, ax = plt.subplots()
# Perform (optional) image corrections
topo.correct_lines()
topo.correct_plane()
topo.filter_lowpass(fft_radius=20)
topo.zero_min()
# Add scale bar, color bar and plot the data
img = topo.show(cmap="inferno", ax=ax)
topo.add_scale(length=1e-6, ax=ax, height=1, fontsize=10)
cbar = fig.colorbar(img)
cbar.set_label("Height (nm)")
dh.save_figures(filepath="afm_topo", fig=fig, only_jpg=True)
g(2) measurements (anti-bunching fit)
Extract an anti-bunching fit from a g(2) measurement.
autocorrelation_measurements = dh.load_measurements("Autocorrelation")
fig, ax = plt.subplots()
for autocorrelation in autocorrelation_measurements.values():
autocorrelation.data["Time (ns)"] = autocorrelation.data["Time (ps)"] * 1e-3
# Plot the data
sns.lineplot(data=autocorrelation.data, x="Time (ns)", y="g2(t) norm", ax=ax)
# Fit the data using the antibunching function
fit_x, fit_y, result = dh.fit(x="Time (ns)", y="g2(t) norm",
data=autocorrelation.data,
fit_function=dh.fit_function.antibunching)
# Plot the fit
sns.lineplot(x=fit_x, y=fit_y, ax=ax, color="C1")
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath="autocorrelation_variation", fig=fig)
ODMR measurements (double Lorentzian fit)
Extract a double Lorentzian fit from an ODMR measurement.
odmr_measurements = dh.load_measurements("ODMR", pulsed=True)
fig, ax = plt.subplots()
for odmr in odmr_measurements.values():
sns.scatterplot(data=odmr.data, x="Controlled variable(Hz)", y="Signal", ax=ax)
fit_x, fit_y, result = dh.fit(x="Controlled variable(Hz)", y="Signal",
data=odmr.data,
fit_function=dh.fit_function.lorentziandouble)
sns.lineplot(x=fit_x, y=fit_y, ax=ax, color="C1")
dh.save_figures(filepath="odmr_variation", fig=fig)
Rabi measurements (sine exp. decay fit)
Extract an exponentially decaying sine fit from a Rabi measurement.
rabi_measurements = dh.load_measurements("Rabi", pulsed=True)
fig, ax = plt.subplots()
for rabi in rabi_measurements.values():
sns.scatterplot(data=rabi.data, x="Controlled variable(s)", y="Signal", ax=ax)
fit_x, fit_y, result = dh.fit(x="Controlled variable(s)", y="Signal",
data=rabi.data,
fit_function=dh.fit_function.sineexponentialdecay)
sns.lineplot(x=fit_x, y=fit_y, ax=ax, color="C1")
dh.save_figures(filepath="rabi_variation", fig=fig)
Temperature measurements
Extract temperature data from a Lakeshore temperature monitor.
temperature_measurements = dh.load_measurements("Temperature", qudi=False)
temperature = pd.concat([t.data for t in temperature_measurements.values()])
fig, ax = plt.subplots()
sns.lineplot(data=temperature, x="Time", y="Temperature", ax=ax)
dh.save_figures(filepath="temperature_monitoring", fig=fig)
Bruker MFM measurements
Extract a heatmap of MFM data from a 2D raster Bruker MFM map.
bruker_measurements = dh.load_measurements("mfm", extension=".001", qudi=False)
bruker_data = bruker_measurements["20230101-0420-00"].data
# Print the channels available in the data
bruker_data.list_channels()
mfm = bruker_data.get_channel("Phase", mfm=True)
fig, ax = plt.subplots()
# Perform (optional) image corrections
mfm.correct_plane()
mfm.zero_min()
# Add scale bar, color bar and plot the data
img = mfm.show(cmap="inferno", ax=ax)
mfm.add_scale(length=1, ax=ax, height=1, fontsize=10)
cbar = fig.colorbar(img)
cbar.set_label("MFM contrast (deg)")
dh.save_figures(filepath="MFM", fig=fig, only_jpg=True)
PYS data (pi3diamond compatibility)
Load a .pys file saved by pi3diamond and plot its counts against the time bins.
pys_measurements = dh.load_measurements("ndmin", extension=".pys", qudi=False)
pys = pys_measurements[list(pys_measurements)[0]].data
fig, ax = plt.subplots()
sns.lineplot(x=pys["time_bins"], y=pys["counts"], ax=ax)
dh.save_figures(filepath="pys_measurement", fig=fig)
Expand source code
"""Analytics suite for qubit SPM using FPGA timetaggers
## Getting started
Start by creating an instance of the `DataHandler` class. Specify the location you
want to load data from (`data_folder`), the location you want to save figures to (
`figure_folder`) and (optionally) the name of the measurement folder (
`measurement_folder`). If a measurement folder is specified, it is appended to both the
data folder path and the figure folder path, giving the full paths used for loading the
measurement data and saving its figures.
```python
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from qudi_hira_analysis import DataHandler
dh = DataHandler(
data_folder=Path("C:/Data"), # Path to the data folder
figure_folder=Path("C:/QudiHiraAnalysis"), # Path to the figure folder
measurement_folder=Path("20230101_NV1") # Name of the measurement folder
)
# Output:
# qudi_hira_analysis.data_handler :: INFO :: Data folder path is C:/Data/20230101_NV1
# qudi_hira_analysis.data_handler :: INFO :: Figure folder path is C:/QudiHiraAnalysis/20230101_NV1
```
### Loading data
To load a specific set of measurements from the data folder, use the
`DataHandler.load_measurements()` method. The method takes a string as an argument
and searches for files with the string in the path. The files are lazy-loaded,
so the data is only loaded when it is needed. The method returns a dictionary,
where the keys are the timestamps of the measurements and the values are
`measurement_dataclass.MeasurementDataclass()` objects.
```python
# Search and lazy-load all pulsed measurements with "odmr" in the path into Dataclasses
odmr_measurements = dh.load_measurements("odmr", pulsed=True)
odmr = odmr_measurements["20230101-0420-00"]
```
>>> odmr
MeasurementDataclass(timestamp='2023-01-01 04:20:00', filename='odmr.dat')
>>> odmr.data
Controlled variable(Hz) Signal
0 2850.000000 1.001626
1 2850.816327 0.992160
2 2851.632653 0.975900
3 2852.448980 0.990770
4 2853.265306 0.994068
... ... ...
### Fitting data
To fit data, call the `DataHandler.fit()` method. This method accepts pandas
DataFrames, numpy arrays or pandas Series as inputs. To get the full list of
available fit routines, explore the `DataHandler.fit_function` attribute or call
`AnalysisLogic.get_all_fits()`.
The fit functions available are:
| Dimension | Fit |
|-----------|-------------------------------|
| 1D | decayexponential |
| | biexponential |
| | decayexponentialstretched |
| | gaussian |
| | gaussiandouble |
| | gaussianlinearoffset |
| | hyperbolicsaturation |
| | linear |
| | lorentzian |
| | lorentziandouble |
| | lorentziantriple |
| | sine |
| | sinedouble |
| | sinedoublewithexpdecay |
| | sinedoublewithtwoexpdecay |
| | sineexponentialdecay |
| | sinestretchedexponentialdecay |
| | sinetriple |
| | sinetriplewithexpdecay |
| | sinetriplewiththreeexpdecay |
| 2D | twoDgaussian |
```python
x_fit, y_fit, result = dh.fit(x="Controlled variable(Hz)", y="Signal",
fit_function=dh.fit_function.lorentziandouble,
data=odmr.data)
# Plot the data and the fit
ax = sns.scatterplot(x="Controlled variable(Hz)", y="Signal", data=odmr.data, label="Data")
sns.lineplot(x=x_fit, y=y_fit, ax=ax, label="Fit")
```
>>> print(result.fit_report())
[[Model]]
(((Model(amplitude_function, prefix='l0_') * Model(physical_lorentzian, prefix='l0_')) + Model(constant_function)) + (Model(amplitude_function, prefix='l1_') * Model(physical_lorentzian, prefix='l1_')))
[[Fit Statistics]]
# fitting method = leastsq
# function evals = 57
# data points = 50
# variables = 7
chi-square = 0.00365104
reduced chi-square = 8.4908e-05
Akaike info crit = -462.238252
Bayesian info crit = -448.854091
R-squared = 0.98267842
[[Variables]]
l0_amplitude: -0.22563916 +/- 0.00770009 (3.41%) (init = -0.1766346)
l0_center: 2866.58588 +/- 0.05208537 (0.00%) (init = 2866.327)
l0_sigma: 1.51496671 +/- 0.08457482 (5.58%) (init = 2.834718)
offset: 1.00213733 +/- 0.00189341 (0.19%) (init = 0.9965666)
l1_amplitude: -0.21380290 +/- 0.00772333 (3.61%) (init = -0.1676134)
l1_center: 2873.49077 +/- 0.05478295 (0.00%) (init = 2872.857)
l1_sigma: 1.49704419 +/- 0.08813686 (5.89%) (init = 2.987285)
l0_fwhm: 3.02993343 +/- 0.16914964 (5.58%) == '2*l0_sigma'
l0_contrast: -22.5157920 +/- 0.76625384 (3.40%) == '(l0_amplitude/offset)*100'
l1_fwhm: 2.99408834 +/- 0.17627372 (5.89%) == '2*l1_sigma'
l1_contrast: -21.3346912 +/- 0.76857477 (3.60%) == '(l1_amplitude/offset)*100'
[[Correlations]] (unreported correlations are < 0.100)
C(l0_amplitude, l0_sigma) = +0.6038
C(l1_amplitude, l1_sigma) = +0.6020
C(l0_sigma, offset) = +0.3301
C(offset, l1_sigma) = +0.3250
C(l0_sigma, l1_sigma) = -0.1588
C(l0_center, l1_sigma) = -0.1531
C(l0_sigma, l1_center) = +0.1531
### Saving data
To save figures, call the `DataHandler.save_figures()` method. By default,
the figures are saved as JPG, PDF, PNG and SVG. This can be changed by setting the
`only_jpg` or `only_pdf` arguments to `True`. All other keyword arguments are passed
to the `matplotlib.pyplot.savefig()` function.
```python
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath=Path("odmr"), fig=ax.get_figure(),
only_pdf=True, bbox_inches="tight")
# The figure is saved to C:/QudiHiraAnalysis/20230101_NV1/odmr.pdf
```
## Examples
### NV-ODMR map
Extract a heatmap of ODMR splittings from a 2D raster NV-ODMR map.
```python
# Extract ODMR measurements from the measurement folder
odmr_measurements = dh.load_measurements("2d_odmr_map")
odmr_measurements = dict(sorted(odmr_measurements.items()))
# Perform parallel (=num CPU cores) ODMR fitting
odmr_measurements = dh.fit_raster_odmr(odmr_measurements)
# Calculate 2D ODMR map from the fitted ODMR measurements
pixels = int(np.sqrt(len(odmr_measurements)))
image = np.zeros((pixels, pixels))
for odmr in odmr_measurements.values():
    row, col = odmr.xy_position
    if len(odmr.fit_model.params) > 6:
        # Calculate double Lorentzian splitting
        image[row, col] = np.abs(odmr.fit_model.best_values["l1_center"]
                                 - odmr.fit_model.best_values["l0_center"])
ax = sns.heatmap(image, cbar_kws={"label": "Delta E (MHz)"})
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath="2d_odmr_map", fig=ax.get_figure(), only_jpg=True)
```
### NV-PL map
Extract a heatmap of NV photo-luminescence from a 2D raster NV-PL map.
```python
pixel_scanner_measurements = dh.load_measurements("PixelScanner")
fwd, bwd = pixel_scanner_measurements["20230101-0420-00"].data
# If size is known, it can be specified here
fwd.size["real"] = {"x": 1e-6, "y": 1e-6, "unit": "m"}
fig, ax = plt.subplots()
# Perform (optional) image corrections
fwd.filter_gaussian(sigma=0.5)
# Add scale bar, color bar and plot the data
img = fwd.show(cmap="inferno", ax=ax)
fwd.add_scale(length=1e-6, ax=ax, height=1)
cbar = fig.colorbar(img)
cbar.set_label("NV-PL (kcps)")
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath="nv_pl_scan", fig=fig, only_jpg=True)
```
### Nanonis AFM measurements
Extract a heatmap of AFM data from a 2D raster Nanonis AFM scan.
```python
afm_measurements = dh.load_measurements("Scan", extension=".sxm", qudi=False)
afm = afm_measurements["20230101-0420-00"].data
# Print the channels available in the data
afm.list_channels()
topo = afm.get_channel("Z")
fig, ax = plt.subplots()
# Perform (optional) image corrections
topo.correct_lines()
topo.correct_plane()
topo.filter_lowpass(fft_radius=20)
topo.zero_min()
# Add scale bar, color bar and plot the data
img = topo.show(cmap="inferno", ax=ax)
topo.add_scale(length=1e-6, ax=ax, height=1, fontsize=10)
cbar = fig.colorbar(img)
cbar.set_label("Height (nm)")
dh.save_figures(filepath="afm_topo", fig=fig, only_jpg=True)
```
### g(2) measurements (anti-bunching fit)
Extract an anti-bunching fit from a g(2) measurement.
```python
autocorrelation_measurements = dh.load_measurements("Autocorrelation")
fig, ax = plt.subplots()
for autocorrelation in autocorrelation_measurements.values():
autocorrelation.data["Time (ns)"] = autocorrelation.data["Time (ps)"] * 1e-3
# Plot the data
sns.lineplot(data=autocorrelation.data, x="Time (ns)", y="g2(t) norm", ax=ax)
# Fit the data using the antibunching function
fit_x, fit_y, result = dh.fit(x="Time (ns)", y="g2(t) norm",
data=autocorrelation.data,
fit_function=dh.fit_function.antibunching)
# Plot the fit
sns.lineplot(x=fit_x, y=fit_y, ax=ax, color="C1")
# Save the figure to the figure folder specified earlier
dh.save_figures(filepath="autocorrelation_variation", fig=fig)
```
### ODMR measurements (double Lorentzian fit)
Extract a double Lorentzian fit from an ODMR measurement.
```python
odmr_measurements = dh.load_measurements("ODMR", pulsed=True)
fig, ax = plt.subplots()
for odmr in odmr_measurements.values():
sns.scatterplot(data=odmr.data, x="Controlled variable(Hz)", y="Signal", ax=ax)
fit_x, fit_y, result = dh.fit(x="Controlled variable(Hz)", y="Signal",
data=odmr.data,
fit_function=dh.fit_function.lorentziandouble)
sns.lineplot(x=fit_x, y=fit_y, ax=ax, color="C1")
dh.save_figures(filepath="odmr_variation", fig=fig)
```
### Rabi measurements (sine exp. decay fit)
Extract an exponentially decaying sine fit from a Rabi measurement.
```python
rabi_measurements = dh.load_measurements("Rabi", pulsed=True)
fig, ax = plt.subplots()
for rabi in rabi_measurements.values():
sns.scatterplot(data=rabi.data, x="Controlled variable(s)", y="Signal", ax=ax)
fit_x, fit_y, result = dh.fit(x="Controlled variable(s)", y="Signal",
data=rabi.data,
fit_function=dh.fit_function.sineexponentialdecay)
sns.lineplot(x=fit_x, y=fit_y, ax=ax, color="C1")
dh.save_figures(filepath="rabi_variation", fig=fig)
```
### Temperature measurements
Extract temperature data from a Lakeshore temperature monitor.
```python
temperature_measurements = dh.load_measurements("Temperature", qudi=False)
temperature = pd.concat([t.data for t in temperature_measurements.values()])
fig, ax = plt.subplots()
sns.lineplot(data=temperature, x="Time", y="Temperature", ax=ax)
dh.save_figures(filepath="temperature_monitoring", fig=fig)
```
### Bruker MFM measurements
Extract a heatmap of MFM data from a 2D raster Bruker MFM map.
```python
bruker_measurements = dh.load_measurements("mfm", extension=".001", qudi=False)
bruker_data = bruker_measurements["20230101-0420-00"].data
# Print the channels available in the data
bruker_data.list_channels()
mfm = bruker_data.get_channel("Phase", mfm=True)
fig, ax = plt.subplots()
# Perform (optional) image corrections
mfm.correct_plane()
mfm.zero_min()
# Add scale bar, color bar and plot the data
img = mfm.show(cmap="inferno", ax=ax)
mfm.add_scale(length=1, ax=ax, height=1, fontsize=10)
cbar = fig.colorbar(img)
cbar.set_label("MFM contrast (deg)")
dh.save_figures(filepath="MFM", fig=fig, only_jpg=True)
```
### PYS data (pi3diamond compatibility)
Load a `.pys` file saved by pi3diamond and plot its counts against the time bins.
```python
pys_measurements = dh.load_measurements("ndmin", extension=".pys", qudi=False)
pys = pys_measurements[list(pys_measurements)[0]].data
fig, ax = plt.subplots()
sns.lineplot(x=pys["time_bins"], y=pys["counts"], ax=ax)
dh.save_figures(filepath="pys_measurement", fig=fig)
```
"""
from .analysis_logic import AnalysisLogic, FitMethodsAndEstimators
from .data_handler import DataHandler
from .io_handler import IOHandler
# Public API of the package: the classes imported above from the submodules.
# These are the names exported by `from qudi_hira_analysis import *`.
__all__ = ["DataHandler", "IOHandler", "AnalysisLogic", "FitMethodsAndEstimators"]
Sub-modules
qudi_hira_analysis.analysis_logic
qudi_hira_analysis.data_handler
qudi_hira_analysis.helper_functions
qudi_hira_analysis.io_handler
qudi_hira_analysis.measurement_dataclass
Classes
class AnalysisLogic
-
Class for performing analysis on measurement data
Expand source code
class AnalysisLogic(FitLogic):
    """
    Class for performing analysis on measurement data.

    Wraps the qudi ``FitLogic`` machinery behind a simpler ``fit`` entry point
    and adds helpers for pulsed laser traces and raster ODMR scans.
    """
    # (fit method, estimator) pairs; pass one of them as ``fit_function`` to ``fit``.
    fit_function = FitMethodsAndEstimators

    def __init__(self):
        super().__init__()
        self.log = logging.getLogger(__name__)

    def _perform_fit(
            self,
            x: np.ndarray,
            y: np.ndarray,
            fit_function: str,
            estimator: str,
            parameters: list[Parameter] | None = None,
            dims: str = "1d"
    ) -> tuple[np.ndarray, np.ndarray, ModelResult]:
        """
        Run a single fit through a temporary fit container.

        Args:
            x: x data
            y: y data
            fit_function: name of the fit function
            estimator: name of the estimator
            parameters: parameters to add to the fit and flag in ``use_settings`` (optional)
            dims: "1d" or "2d"

        Returns:
            Fit x data, fit y data and lmfit ModelResult
        """
        fit = {dims: {"default": {"fit_function": fit_function, "estimator": estimator}}}
        user_fit = self.validate_load_fits(fit)
        default_fit = user_fit[dims]["default"]

        if parameters:
            default_fit["parameters"].add_many(*parameters)

        # use_settings is True only for the parameters explicitly passed by the caller
        user_names = {p.name for p in parameters} if parameters else set()
        default_fit["use_settings"] = {
            key: key in user_names for key in default_fit["parameters"]
        }

        fc = self.make_fit_container("test", dims)
        fc.set_fit_functions(user_fit[dims])
        fc.set_current_fit("default")
        fit_x, fit_y, result = fc.do_fit(x, y)
        return fit_x, fit_y, result

    def fit(
            self,
            x: str | np.ndarray | pd.Series,
            y: str | np.ndarray | pd.Series,
            fit_function: FitMethodsAndEstimators,
            data: pd.DataFrame = None,
            parameters: list[Parameter] | None = None
    ) -> tuple[np.ndarray, np.ndarray, ModelResult]:
        """
        Fit y(x) with the given fit function.

        Args:
            x: x data, can be string, numpy array or pandas Series/Index
            y: y data, can be string, numpy array or pandas Series/Index
            fit_function: (fit method, estimator) pair to use
            data: pandas DataFrame containing x and y data (x and y are then column
                names); if None, x and y must be numpy arrays or pandas Series/Index
            parameters: list of parameters to use in fit (optional)

        Returns:
            Fit x data, fit y data and lmfit ModelResult

        Raises:
            TypeError: If ``data`` is neither None nor a pandas DataFrame.
        """
        dims: str = "2d" if "twoD" in fit_function[0] else "1d"

        if data is None:
            # Treat both axes alike, e.g. ``df.index`` may be passed for either
            if isinstance(x, (pd.Series, pd.Index)):
                x = x.to_numpy()
            if isinstance(y, (pd.Series, pd.Index)):
                y = y.to_numpy()
        elif isinstance(data, pd.DataFrame):
            x = data[x].to_numpy()
            y = data[y].to_numpy()
        else:
            raise TypeError("Data must be a pandas DataFrame or None")

        return self._perform_fit(
            x=x,
            y=y,
            fit_function=fit_function[0],
            estimator=fit_function[1],
            parameters=parameters,
            dims=dims
        )

    def get_all_fits(self) -> tuple[list, list]:
        """Get all available fits

        Returns:
            Tuple with list of 1d and 2d fits
        """
        one_d_fits: list = list(self.fit_list['1d'].keys())
        two_d_fits: list = list(self.fit_list['2d'].keys())
        self.log.info(f"1d fits: {one_d_fits}\n2d fits: {two_d_fits}")
        return one_d_fits, two_d_fits

    @staticmethod
    def analyze_mean(
            laser_data: np.ndarray,
            signal_start: float = 100e-9,
            signal_end: float = 300e-9,
            bin_width: float = 1e-9
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Calculate the mean of the signal window.

        Args:
            laser_data: 2D array of laser data, one row per laser pulse
            signal_start: start of the signal window in seconds
            signal_end: end of the signal window in seconds
            bin_width: width of a bin in seconds

        Returns:
            Mean of the signal window and measurement error. Both are all zeros
            if ``bin_width`` is not a float.
        """
        num_of_lasers = laser_data.shape[0]
        if not isinstance(bin_width, float):
            return np.zeros(num_of_lasers), np.zeros(num_of_lasers)

        # Convert the times in seconds to bins (i.e. array indices)
        signal_start_bin = round(signal_start / bin_width)
        signal_end_bin = round(signal_end / bin_width)

        signal_data = np.empty(num_of_lasers, dtype=float)
        error_data = np.empty(num_of_lasers, dtype=float)
        for ii, laser_arr in enumerate(laser_data):
            window = laser_arr[signal_start_bin:signal_end_bin]
            signal = window.mean()
            signal_error = np.sqrt(window.sum()) / (signal_end_bin - signal_start_bin)
            # Avoid numpy C type variables overflow and NaN values
            if signal < 0 or np.isnan(signal):
                signal_data[ii] = 0.0
                error_data[ii] = 0.0
            else:
                signal_data[ii] = signal
                error_data[ii] = signal_error
        return signal_data, error_data

    @staticmethod
    def analyze_mean_reference(
            laser_data: np.ndarray,
            signal_start: float = 100e-9,
            signal_end: float = 300e-9,
            norm_start: float = 1000e-9,
            norm_end: float = 2000e-9,
            bin_width: float = 1e-9
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Subtract the mean of the reference window from the mean of the signal window.

        Args:
            laser_data: 2D array of laser data, one row per laser pulse
            signal_start: start of the signal window in seconds
            signal_end: end of the signal window in seconds
            norm_start: start of the reference window in seconds
            norm_end: end of the reference window in seconds
            bin_width: width of a bin in seconds

        Returns:
            Referenced mean of the signal window and its (non-negative)
            measurement error. The error is 0 when either window sums to zero.
            Both arrays are all zeros if ``bin_width`` is not a float.
        """
        num_of_lasers = laser_data.shape[0]
        if not isinstance(bin_width, float):
            return np.zeros(num_of_lasers), np.zeros(num_of_lasers)

        # Convert the times in seconds to bins (i.e. array indices)
        signal_start_bin = round(signal_start / bin_width)
        signal_end_bin = round(signal_end / bin_width)
        norm_start_bin = round(norm_start / bin_width)
        norm_end_bin = round(norm_end / bin_width)

        signal_data = np.empty(num_of_lasers, dtype=float)
        error_data = np.empty(num_of_lasers, dtype=float)
        for ii, laser_arr in enumerate(laser_data):
            counts = laser_arr[norm_start_bin:norm_end_bin]
            reference_sum = np.sum(counts)
            reference_mean = (reference_sum / len(counts)) if len(counts) != 0 else 0.0

            counts = laser_arr[signal_start_bin:signal_end_bin]
            signal_sum = np.sum(counts)
            signal_mean = (signal_sum / len(counts)) if len(counts) != 0 else 0.0

            signal_data[ii] = signal_mean - reference_mean
            # Gaussian error 'evolution'; the magnitude is taken so error bars are
            # never negative, and empty windows give 0 instead of inf/NaN.
            # NOTE(review): this is the relative-error form used for ratios; for a
            # plain difference absolute errors would add in quadrature — confirm intent.
            if signal_sum != 0 and reference_sum != 0:
                error_data[ii] = abs(signal_data[ii]) * np.sqrt(
                    1 / abs(signal_sum) + 1 / abs(reference_sum))
            else:
                error_data[ii] = 0.0
        return signal_data, error_data

    @staticmethod
    def analyze_mean_norm(
            laser_data: np.ndarray,
            signal_start: float = 100e-9,
            signal_end: float = 300e-9,
            norm_start: float = 1000e-9,
            norm_end=2000e-9,
            bin_width: float = 1e-9
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Divide the mean of the signal window by the mean of the reference window.

        Args:
            laser_data: 2D array of laser data, one row per laser pulse
            signal_start: start of the signal window in seconds
            signal_end: end of the signal window in seconds
            norm_start: start of the reference window in seconds
            norm_end: end of the reference window in seconds
            bin_width: width of a bin in seconds

        Returns:
            Normalized mean of the signal window and measurement error. Both are
            all zeros if ``bin_width`` is not a float.
        """
        num_of_lasers = laser_data.shape[0]
        if not isinstance(bin_width, float):
            return np.zeros(num_of_lasers), np.zeros(num_of_lasers)

        # Convert the times in seconds to bins (i.e. array indices)
        signal_start_bin = round(signal_start / bin_width)
        signal_end_bin = round(signal_end / bin_width)
        norm_start_bin = round(norm_start / bin_width)
        norm_end_bin = round(norm_end / bin_width)

        signal_data = np.empty(num_of_lasers, dtype=float)
        error_data = np.empty(num_of_lasers, dtype=float)
        for ii, laser_arr in enumerate(laser_data):
            counts = laser_arr[norm_start_bin:norm_end_bin]
            reference_sum = np.sum(counts)
            reference_mean = (reference_sum / len(counts)) if len(counts) != 0 else 0.0

            counts = laser_arr[signal_start_bin:signal_end_bin]
            signal_sum = np.sum(counts)
            signal_mean = (signal_sum / len(counts)) if len(counts) != 0 else 0.0

            # Calculate normalized signal while avoiding division by zero
            if reference_mean > 0 and signal_mean >= 0:
                signal_data[ii] = signal_mean / reference_mean
            else:
                signal_data[ii] = 0.0

            # Calculate measurement error while avoiding division by zero
            if reference_sum > 0 and signal_sum > 0:
                # calculate with respect to gaussian error 'evolution'
                error_data[ii] = signal_data[ii] * np.sqrt(
                    1 / signal_sum + 1 / reference_sum)
            else:
                error_data[ii] = 0.0
        return signal_data, error_data

    def optimize_raster_odmr_params(
            self,
            measurements: dict[str, MeasurementDataclass],
            num_samples: int = 10,
            num_params: int = 3,
    ) -> tuple[float, tuple[float, float, float]]:
        """
        Grid-search the peak-finding hyperparameters of ``fit_raster_odmr``.

        A random subset of the measurements is fitted for every combination of
        ``r2_thresh``, ``thresh_frac`` and ``sigma_thresh_frac`` on a small grid.
        The combination with the highest worst-case (minimum) R2 is returned.

        Args:
            measurements: A dictionary of measurements to optimize the hyperparameters.
            num_samples: The number of measurements to sample, capped at
                ``len(measurements)``.
            num_params: The number of grid values tried for each of the three
                hyperparameters (values that coincide after rounding are tried once).

        Returns:
            The highest minimum R2 value and the optimized hyperparameters
            ``(r2_thresh, thresh_frac, sigma_thresh_frac)``. ``(0, (0, 0, 0))`` is
            returned if no combination reaches a positive minimum R2.

        Raises:
            ValueError: If ``measurements`` is empty.
        """
        if not measurements:
            raise ValueError("No measurements given to optimize the ODMR fit parameters.")

        # Rounding can map neighbouring grid points onto the same value; np.unique
        # drops those duplicates (keeping ascending order) so they are fitted once.
        r2_threshs: np.ndarray = np.unique(np.around(
            np.linspace(start=0.9, stop=0.99, num=num_params), decimals=2
        ))
        thresh_fracs: np.ndarray = np.unique(np.around(
            np.linspace(start=0.5, stop=0.9, num=num_params), decimals=1
        ))
        sigma_thresh_fracs: np.ndarray = np.unique(np.around(
            np.linspace(start=0.1, stop=0.2, num=num_params), decimals=1
        ))

        # random.sample raises ValueError when k exceeds the population size
        sample_size = min(num_samples, len(measurements))
        odmr_sample: dict = dict(random.sample(sorted(measurements.items()), k=sample_size))

        highest_min_r2: float = 0
        optimal_params: tuple[float, float, float] = (0, 0, 0)

        for r2_thresh, thresh_frac, sigma_thresh_frac in product(
                r2_threshs, thresh_fracs, sigma_thresh_fracs):
            odmr_sample = self.fit_raster_odmr(
                odmr_sample,
                r2_thresh=r2_thresh,
                thresh_frac=thresh_frac,
                sigma_thresh_frac=sigma_thresh_frac,
                min_thresh=0.01,
                progress_bar=False
            )

            r2s: np.ndarray = np.array([odmr.fit_model.rsquared for odmr in odmr_sample.values()])
            min_r2: float = np.min(r2s)

            if highest_min_r2 < min_r2:
                highest_min_r2 = min_r2
                optimal_params = (r2_thresh, thresh_frac, sigma_thresh_frac)

        return highest_min_r2, optimal_params

    @staticmethod
    def _lorentzian_fitting(
            x: np.ndarray,
            y: np.ndarray,
            model1: Model,
            model2: Model,
            params1: Parameters,
            params2: Parameters,
            r2_thresh: float
    ) -> ModelResult:
        """ Make Lorentzian fitting for single and double Lorentzian model """
        res1 = rof.make_lorentzian_fit(x, y, model1, params1)
        # Fall back to a double Lorentzian when the single fit is too poor
        if res1.rsquared < r2_thresh:
            return rof.make_lorentziandouble_fit(x, y, model2, params2)
        return res1

    def fit_raster_odmr(
            self,
            odmr_measurements: dict[str, MeasurementDataclass],
            r2_thresh: float = 0.95,
            thresh_frac: float = 0.5,
            sigma_thresh_frac: float = 0.15,
            min_thresh: float = 0.01,
            extract_pixel_from_filename: bool = True,
            progress_bar: bool = True
    ) -> dict[str, MeasurementDataclass]:
        """
        Fit a dict of ODMR data to single and double Lorentzian functions.

        Args:
            odmr_measurements: Dict of ODMR data in MeasurementDataclasses
            r2_thresh: R^2 Threshold below which a double lorentzian is fitted instead
                of a single lorentzian
            thresh_frac: Threshold fraction for the peak finding
            min_thresh: Minimum threshold for the peak finding
            sigma_thresh_frac: Change in threshold fraction for the peak finding
            extract_pixel_from_filename: Extract `(row, col)` (in this format) from filename
            progress_bar: Show progress bar

        Returns:
            The same dict, with ``fit_model``, ``fit_data`` and (optionally)
            ``xy_position`` set on every MeasurementDataclass. An empty dict is
            returned unchanged.
        """
        if not odmr_measurements:
            return odmr_measurements

        model1, base_params1 = rof.make_lorentzian_model()
        model2, base_params2 = rof.make_lorentziandouble_model()

        # Generate arguments for the parallel fitting
        args = []
        for odmr in tqdm(odmr_measurements.values(), disable=not progress_bar):
            x = odmr.data["Freq(MHz)"].to_numpy()
            y = odmr.data["Counts"].to_numpy()

            _, params1 = rof.estimate_lorentzian_dip(x, y, base_params1)
            _, params2 = rof.estimate_lorentziandouble_dip(
                x, y, base_params2, thresh_frac, min_thresh, sigma_thresh_frac
            )
            args.append((x, y, model1, model2, params1, params2, r2_thresh))

        # Parallel fitting
        model_results = Parallel(n_jobs=cpu_count())(
            delayed(self._lorentzian_fitting)(*fit_args)
            for fit_args in tqdm(args, disable=not progress_bar)
        )

        for odmr, fit_args, res in zip(odmr_measurements.values(), args, model_results):
            # Evaluate each fit on its own frequency axis (twice the sampling density)
            x = fit_args[0]
            x_fit = np.linspace(start=x[0], stop=x[-1], num=int(len(x) * 2))

            if len(res.params) == 6:
                # Evaluate a single Lorentzian
                y_fit = model1.eval(x=x_fit, params=res.params)
            else:
                # Evaluate a double Lorentzian
                y_fit = model2.eval(x=x_fit, params=res.params)

            # Plug results into the DataClass
            odmr.fit_model = res
            odmr.fit_data = pd.DataFrame(np.vstack((x_fit, y_fit)).T, columns=["x_fit", "y_fit"])

            if extract_pixel_from_filename:
                # Extract the pixel with regex from the filename
                row, col = map(
                    int, re.findall(r'(?<=\().*?(?=\))', odmr.filename)[0].split(",")
                )
                odmr.xy_position = (row, col)

        return odmr_measurements

    @staticmethod
    def average_raster_odmr_pixels(orig_image: np.ndarray) -> np.ndarray:
        """
        Average a NaN pixel to its surrounding pixels.

        The neighbourhood is the 3x3 window around the pixel, clipped to the image
        bounds, so edge and corner pixels use every neighbour that exists. Pixels
        are filled in row-major order, so a NaN pixel may use values already filled
        in for earlier NaN neighbours. A pixel whose whole neighbourhood is NaN
        stays NaN.

        Args:
            orig_image: 2D image with NaN pixels (not modified)

        Returns:
            Image with NaN pixels replaced by the average of its surrounding pixels
        """
        image: np.ndarray = orig_image.copy()
        n_rows, n_cols = image.shape
        for row, col in np.argwhere(np.isnan(image)):
            window = image[max(row - 1, 0):min(row + 2, n_rows),
                           max(col - 1, 0):min(col + 2, n_cols)]
            if np.isnan(window).all():
                # Nothing to average yet; avoid nanmean's "Mean of empty slice" warning
                continue
            image[row, col] = np.nanmean(window)
        return image
Ancestors
- qudi_hira_analysis._qudi_fit_logic.FitLogic
Subclasses
- qudi_hira_analysis.data_handler.DataHandler
Class variables
var fit_function
-
Class for storing fit methods and estimators. Fit methods are stored as tuples of (method, estimator) where method is the name of the fit method and estimator is the name of the estimator.
The fit functions available are:
Dimension | Fit
1D | decayexponential, biexponential, decayexponentialstretched, gaussian, gaussiandouble, gaussianlinearoffset, hyperbolicsaturation, linear, lorentzian, lorentziandouble, lorentziantriple, sine, sinedouble, sinedoublewithexpdecay, sinedoublewithtwoexpdecay, sineexponentialdecay, sinestretchedexponentialdecay, sinetriple, sinetriplewithexpdecay, sinetriplewiththreeexpdecay
2D | twoDgaussian
Static methods
def analyze_mean(laser_data: np.ndarray, signal_start: float = 1e-07, signal_end: float = 3e-07, bin_width: float = 1e-09) ‑> tuple[numpy.ndarray, numpy.ndarray]
-
Calculate the mean of the signal window.
Args
laser_data
- 2D array of laser data
signal_start
- start of the signal window in seconds
signal_end
- end of the signal window in seconds
bin_width
- width of a bin in seconds
Returns
Mean of the signal window and measurement error
Expand source code
@staticmethod def analyze_mean( laser_data: np.ndarray, signal_start: float = 100e-9, signal_end: float = 300e-9, bin_width: float = 1e-9 ) -> tuple[np.ndarray, np.ndarray]: """ Calculate the mean of the signal window. Args: laser_data: 2D array of laser data signal_start: start of the signal window in seconds signal_end: end of the signal window in seconds bin_width: width of a bin in seconds Returns: Mean of the signal window and measurement error """ # Get number of lasers num_of_lasers = laser_data.shape[0] if not isinstance(bin_width, float): return np.zeros(num_of_lasers), np.zeros(num_of_lasers) # Convert the times in seconds to bins (i.e. array indices) signal_start_bin = round(signal_start / bin_width) signal_end_bin = round(signal_end / bin_width) # initialize data arrays for signal and measurement error signal_data = np.empty(num_of_lasers, dtype=float) error_data = np.empty(num_of_lasers, dtype=float) # loop over all laser pulses and analyze them for ii, laser_arr in enumerate(laser_data): # calculate the mean of the data in the signal window signal = laser_arr[signal_start_bin:signal_end_bin].mean() signal_sum = laser_arr[signal_start_bin:signal_end_bin].sum() signal_error = np.sqrt(signal_sum) / (signal_end_bin - signal_start_bin) # Avoid numpy C type variables overflow and NaN values if signal < 0 or signal != signal: signal_data[ii] = 0.0 error_data[ii] = 0.0 else: signal_data[ii] = signal error_data[ii] = signal_error return signal_data, error_data
def analyze_mean_norm(laser_data: np.ndarray, signal_start: float = 1e-07, signal_end: float = 3e-07, norm_start: float = 1e-06, norm_end=2e-06, bin_width: float = 1e-09) ‑> tuple[numpy.ndarray, numpy.ndarray]
-
Divides the mean of the signal window by the mean of the reference window.
Args
laser_data
- 2D array of laser data
signal_start
- start of the signal window in seconds
signal_end
- end of the signal window in seconds
norm_start
- start of the reference window in seconds
norm_end
- end of the reference window in seconds
bin_width
- width of a bin in seconds
Returns
Normalized mean of the signal window and measurement error
Expand source code
@staticmethod def analyze_mean_norm( laser_data: np.ndarray, signal_start: float = 100e-9, signal_end: float = 300e-9, norm_start: float = 1000e-9, norm_end=2000e-9, bin_width: float = 1e-9 ) -> tuple[np.ndarray, np.ndarray]: """ Divides the mean of the signal window from the mean of the reference window. Args: laser_data: 2D array of laser data signal_start: start of the signal window in seconds signal_end: end of the signal window in seconds norm_start: start of the reference window in seconds norm_end: end of the reference window in seconds bin_width: width of a bin in seconds Returns: Normalized mean of the signal window and measurement error """ # Get number of lasers num_of_lasers = laser_data.shape[0] if not isinstance(bin_width, float): return np.zeros(num_of_lasers), np.zeros(num_of_lasers) # Convert the times in seconds to bins (i.e. array indices) signal_start_bin = round(signal_start / bin_width) signal_end_bin = round(signal_end / bin_width) norm_start_bin = round(norm_start / bin_width) norm_end_bin = round(norm_end / bin_width) # initialize data arrays for signal and measurement error signal_data = np.empty(num_of_lasers, dtype=float) error_data = np.empty(num_of_lasers, dtype=float) # loop over all laser pulses and analyze them for ii, laser_arr in enumerate(laser_data): # calculate the sum and mean of the data in the normalization window counts = laser_arr[norm_start_bin:norm_end_bin] reference_sum = np.sum(counts) reference_mean = (reference_sum / len(counts)) if len(counts) != 0 else 0.0 # calculate the sum and mean of the data in the signal window counts = laser_arr[signal_start_bin:signal_end_bin] signal_sum = np.sum(counts) signal_mean = (signal_sum / len(counts)) if len(counts) != 0 else 0.0 # Calculate normalized signal while avoiding division by zero if reference_mean > 0 and signal_mean >= 0: signal_data[ii] = signal_mean / reference_mean else: signal_data[ii] = 0.0 # Calculate measurement error while avoiding division by zero if 
reference_sum > 0 and signal_sum > 0: # calculate with respect to gaussian error 'evolution' error_data[ii] = signal_data[ii] * np.sqrt( 1 / signal_sum + 1 / reference_sum) else: error_data[ii] = 0.0 return signal_data, error_data
def analyze_mean_reference(laser_data: np.ndarray, signal_start: float = 1e-07, signal_end: float = 3e-07, norm_start: float = 1e-06, norm_end: float = 2e-06, bin_width: float = 1e-09) ‑> tuple[numpy.ndarray, numpy.ndarray]
-
Subtracts the mean of the reference window from the mean of the signal window.
Args
laser_data
- 2D array of laser data
signal_start
- start of the signal window in seconds
signal_end
- end of the signal window in seconds
norm_start
- start of the reference window in seconds
norm_end
- end of the reference window in seconds
bin_width
- width of a bin in seconds
Returns
Referenced mean of the signal window and measurement error
Expand source code
@staticmethod def analyze_mean_reference( laser_data: np.ndarray, signal_start: float = 100e-9, signal_end: float = 300e-9, norm_start: float = 1000e-9, norm_end: float = 2000e-9, bin_width: float = 1e-9) -> tuple[np.ndarray, np.ndarray]: """ Subtracts the mean of the signal window from the mean of the reference window. Args: laser_data: 2D array of laser data signal_start: start of the signal window in seconds signal_end: end of the signal window in seconds norm_start: start of the reference window in seconds norm_end: end of the reference window in seconds bin_width: width of a bin in seconds Returns: Referenced mean of the signal window and measurement error """ # Get number of lasers num_of_lasers = laser_data.shape[0] if not isinstance(bin_width, float): return np.zeros(num_of_lasers), np.zeros(num_of_lasers) # Convert the times in seconds to bins (i.e. array indices) signal_start_bin = round(signal_start / bin_width) signal_end_bin = round(signal_end / bin_width) norm_start_bin = round(norm_start / bin_width) norm_end_bin = round(norm_end / bin_width) # initialize data arrays for signal and measurement error signal_data = np.empty(num_of_lasers, dtype=float) error_data = np.empty(num_of_lasers, dtype=float) # loop over all laser pulses and analyze them for ii, laser_arr in enumerate(laser_data): # calculate the sum and mean of the data in the normalization window counts = laser_arr[norm_start_bin:norm_end_bin] reference_sum = np.sum(counts) reference_mean = (reference_sum / len(counts)) if len(counts) != 0 else 0.0 # calculate the sum and mean of the data in the signal window counts = laser_arr[signal_start_bin:signal_end_bin] signal_sum = np.sum(counts) signal_mean = (signal_sum / len(counts)) if len(counts) != 0 else 0.0 signal_data[ii] = signal_mean - reference_mean # calculate with respect to gaussian error 'evolution' error_data[ii] = signal_data[ii] * np.sqrt( 1 / abs(signal_sum) + 1 / abs(reference_sum)) return signal_data, error_data
def average_raster_odmr_pixels(orig_image: np.ndarray) ‑> numpy.ndarray
-
Average a NaN pixel to its surrounding pixels.
Args
orig_image
- Image with NaN pixels
Returns
Image with NaN pixels replaced by the average of its surrounding pixels
Expand source code
@staticmethod def average_raster_odmr_pixels(orig_image: np.ndarray) -> np.ndarray: """ Average a NaN pixel to its surrounding pixels. Args: orig_image: Image with NaN pixels Returns: Image with NaN pixels replaced by the average of its surrounding pixels """ image: np.ndarray = orig_image.copy() for row, col in np.argwhere(np.isnan(image)): if row == 0: pixel_avg = np.nanmean(image[row + 1:row + 2, col - 1:col + 2]) elif row == image.shape[0] - 1: pixel_avg = np.nanmean(image[row - 1:row, col - 1:col + 2]) elif col == 0: pixel_avg = np.nanmean(image[row - 1:row + 2, col + 1:col + 2]) elif col == image.shape[1] - 1: pixel_avg = np.nanmean(image[row - 1:row + 2, col - 1:col]) else: pixel_avg = np.nanmean(image[row - 1:row + 2, col - 1:col + 2]) image[row, col] = pixel_avg return image
Methods
def fit(self, x: str | np.ndarray | pd.Series, y: str | np.ndarray | pd.Series, fit_function: FitMethodsAndEstimators, data: pd.DataFrame = None, parameters: list[Parameter] | None = None) ‑> tuple[np.ndarray, np.ndarray, ModelResult]
-
Args
x
- x data, can be string, numpy array or pandas Series
y
- y data, can be string, numpy array or pandas Series
fit_function
- fit function to use
data
- pandas DataFrame containing x and y data; if None, x and y must be numpy arrays or pandas Series
parameters
- list of parameters to use in fit (optional)
Returns
Fit x data, fit y data and lmfit ModelResult
Expand source code
def fit(
        self,
        x: str | np.ndarray | pd.Series,
        y: str | np.ndarray | pd.Series,
        fit_function: FitMethodsAndEstimators,
        data: pd.DataFrame = None,
        parameters: list[Parameter] | None = None
) -> tuple[np.ndarray, np.ndarray, ModelResult]:
    """
    Fit y(x) with the given fit function.

    Args:
        x: x data, can be string, numpy array or pandas Series/Index
        y: y data, can be string, numpy array or pandas Series/Index
        fit_function: (fit method, estimator) pair to use
        data: pandas DataFrame containing x and y data (x and y are then column
            names); if None, x and y must be numpy arrays or pandas Series/Index
        parameters: list of parameters to use in fit (optional)

    Returns:
        Fit x data, fit y data and lmfit ModelResult

    Raises:
        TypeError: If ``data`` is neither None nor a pandas DataFrame.
    """
    dims: str = "2d" if "twoD" in fit_function[0] else "1d"

    if data is None:
        # Treat both axes alike, e.g. ``df.index`` may be passed for either
        if isinstance(x, (pd.Series, pd.Index)):
            x = x.to_numpy()
        if isinstance(y, (pd.Series, pd.Index)):
            y = y.to_numpy()
    elif isinstance(data, pd.DataFrame):
        x = data[x].to_numpy()
        y = data[y].to_numpy()
    else:
        raise TypeError("Data must be a pandas DataFrame or None")

    return self._perform_fit(
        x=x,
        y=y,
        fit_function=fit_function[0],
        estimator=fit_function[1],
        parameters=parameters,
        dims=dims
    )
def fit_raster_odmr(self, odmr_measurements: dict[str, MeasurementDataclass], r2_thresh: float = 0.95, thresh_frac: float = 0.5, sigma_thresh_frac: float = 0.15, min_thresh: float = 0.01, extract_pixel_from_filename: bool = True, progress_bar: bool = True) ‑> dict[str, MeasurementDataclass]
-
Fit a list of ODMR data to single and double Lorentzian functions
Args
odmr_measurements
- Dict of ODMR data in MeasurementDataclasses
r2_thresh
- R^2 Threshold below which a double lorentzian is fitted instead of a single lorentzian
thresh_frac
- Threshold fraction for the peak finding
min_thresh
- Minimum threshold for the peak finding
sigma_thresh_frac
- Change in threshold fraction for the peak finding
extract_pixel_from_filename
- Extract `(row, col)` (in this format) from the filename
progress_bar
- Show progress bar
Returns
Dict of ODMR MeasurementDataclass with fit, fit model and pixels attributes set
Expand source code
def fit_raster_odmr(
        self,
        odmr_measurements: dict[str, MeasurementDataclass],
        r2_thresh: float = 0.95,
        thresh_frac: float = 0.5,
        sigma_thresh_frac: float = 0.15,
        min_thresh: float = 0.01,
        extract_pixel_from_filename: bool = True,
        progress_bar: bool = True
) -> dict[str, MeasurementDataclass]:
    """
    Fit a dict of ODMR data to single and double Lorentzian functions.

    Args:
        odmr_measurements: Dict of ODMR data in MeasurementDataclasses
        r2_thresh: R^2 Threshold below which a double lorentzian is fitted instead
            of a single lorentzian
        thresh_frac: Threshold fraction for the peak finding
        min_thresh: Minimum threshold for the peak finding
        sigma_thresh_frac: Change in threshold fraction for the peak finding
        extract_pixel_from_filename: Extract `(row, col)` (in this format) from filename
        progress_bar: Show progress bar

    Returns:
        The same dict, with ``fit_model``, ``fit_data`` and (optionally)
        ``xy_position`` set on every MeasurementDataclass. An empty dict is
        returned unchanged.
    """
    if not odmr_measurements:
        return odmr_measurements

    model1, base_params1 = rof.make_lorentzian_model()
    model2, base_params2 = rof.make_lorentziandouble_model()

    # Generate arguments for the parallel fitting
    args = []
    for odmr in tqdm(odmr_measurements.values(), disable=not progress_bar):
        x = odmr.data["Freq(MHz)"].to_numpy()
        y = odmr.data["Counts"].to_numpy()

        _, params1 = rof.estimate_lorentzian_dip(x, y, base_params1)
        _, params2 = rof.estimate_lorentziandouble_dip(
            x, y, base_params2, thresh_frac, min_thresh, sigma_thresh_frac
        )
        args.append((x, y, model1, model2, params1, params2, r2_thresh))

    # Parallel fitting
    model_results = Parallel(n_jobs=cpu_count())(
        delayed(self._lorentzian_fitting)(*fit_args)
        for fit_args in tqdm(args, disable=not progress_bar)
    )

    for odmr, fit_args, res in zip(odmr_measurements.values(), args, model_results):
        # Evaluate each fit on its own frequency axis (twice the sampling density)
        x = fit_args[0]
        x_fit = np.linspace(start=x[0], stop=x[-1], num=int(len(x) * 2))

        if len(res.params) == 6:
            # Evaluate a single Lorentzian
            y_fit = model1.eval(x=x_fit, params=res.params)
        else:
            # Evaluate a double Lorentzian
            y_fit = model2.eval(x=x_fit, params=res.params)

        # Plug results into the DataClass
        odmr.fit_model = res
        odmr.fit_data = pd.DataFrame(np.vstack((x_fit, y_fit)).T, columns=["x_fit", "y_fit"])

        if extract_pixel_from_filename:
            # Extract the pixel with regex from the filename
            row, col = map(
                int, re.findall(r'(?<=\().*?(?=\))', odmr.filename)[0].split(",")
            )
            odmr.xy_position = (row, col)

    return odmr_measurements
def get_all_fits(self) ‑> tuple[list, list]
-
Get all available fits
Returns
Tuple with list of 1d and 2d fits
Expand source code
def get_all_fits(self) -> tuple[list, list]:
    """Return the names of all registered fits, split by dimensionality.

    The names are also logged at INFO level.

    Returns:
        Tuple ``(one_d_fits, two_d_fits)`` of fit-name lists
    """
    one_d_fits, two_d_fits = (list(self.fit_list[dim].keys()) for dim in ("1d", "2d"))
    self.log.info(f"1d fits: {one_d_fits}\n2d fits: {two_d_fits}")
    return one_d_fits, two_d_fits
def optimize_raster_odmr_params(self, measurements: dict[str, MeasurementDataclass], num_samples: int = 10, num_params: int = 3) ‑> tuple[float, tuple[float, float, float]]
-
This method optimizes the hyperparameters of the ODMR analysis. It does so by randomly sampling a subset of the measurements and then optimizing the hyperparameters for them.
Args
measurements
- A dictionary of measurements to optimize the hyperparameters.
num_params
- The number of grid values tried for each of the three hyperparameters.
num_samples
- The number of measurements to sample.
Returns
The highest minimum R2 value and the optimized hyperparameters.
Expand source code
def optimize_raster_odmr_params(
        self,
        measurements: dict[str, MeasurementDataclass],
        num_samples: int = 10,
        num_params: int = 3,
) -> tuple[float, tuple[float, float, float]]:
    """
    Grid-search the peak-finding hyperparameters of ``fit_raster_odmr``.

    A random subset of the measurements is fitted for every combination of
    ``r2_thresh``, ``thresh_frac`` and ``sigma_thresh_frac`` on a small grid.
    The combination with the highest worst-case (minimum) R2 is returned.

    Args:
        measurements: A dictionary of measurements to optimize the hyperparameters.
        num_samples: The number of measurements to sample, capped at
            ``len(measurements)``.
        num_params: The number of grid values tried for each of the three
            hyperparameters (values that coincide after rounding are tried once).

    Returns:
        The highest minimum R2 value and the optimized hyperparameters
        ``(r2_thresh, thresh_frac, sigma_thresh_frac)``. ``(0, (0, 0, 0))`` is
        returned if no combination reaches a positive minimum R2.

    Raises:
        ValueError: If ``measurements`` is empty.
    """
    if not measurements:
        raise ValueError("No measurements given to optimize the ODMR fit parameters.")

    # Rounding can map neighbouring grid points onto the same value; np.unique
    # drops those duplicates (keeping ascending order) so they are fitted once.
    r2_threshs: np.ndarray = np.unique(np.around(
        np.linspace(start=0.9, stop=0.99, num=num_params), decimals=2
    ))
    thresh_fracs: np.ndarray = np.unique(np.around(
        np.linspace(start=0.5, stop=0.9, num=num_params), decimals=1
    ))
    sigma_thresh_fracs: np.ndarray = np.unique(np.around(
        np.linspace(start=0.1, stop=0.2, num=num_params), decimals=1
    ))

    # random.sample raises ValueError when k exceeds the population size
    sample_size = min(num_samples, len(measurements))
    odmr_sample: dict = dict(random.sample(sorted(measurements.items()), k=sample_size))

    highest_min_r2: float = 0
    optimal_params: tuple[float, float, float] = (0, 0, 0)

    for r2_thresh, thresh_frac, sigma_thresh_frac in product(
            r2_threshs, thresh_fracs, sigma_thresh_fracs):
        odmr_sample = self.fit_raster_odmr(
            odmr_sample,
            r2_thresh=r2_thresh,
            thresh_frac=thresh_frac,
            sigma_thresh_frac=sigma_thresh_frac,
            min_thresh=0.01,
            progress_bar=False
        )

        r2s: np.ndarray = np.array([odmr.fit_model.rsquared for odmr in odmr_sample.values()])
        min_r2: float = np.min(r2s)

        if highest_min_r2 < min_r2:
            highest_min_r2 = min_r2
            optimal_params = (r2_thresh, thresh_frac, sigma_thresh_frac)

    return highest_min_r2, optimal_params
class DataHandler (data_folder: Path, figure_folder: Path, measurement_folder: Path = PosixPath('.'), copy_measurement_folder_structure: bool = True)
-
Handles automated data searching and extraction into dataclasses.
Args
data_folder
- Path to the data folder.
figure_folder
- Path to the figure folder.
measurement_folder
- Path to the measurement folder.
Examples
Create an instance of the DataHandler class:
>>> dh = DataHandler(
...     data_folder=Path("C:/Data"),
...     figure_folder=Path("C:/QudiHiraAnalysis"),
...     measurement_folder=Path("20230101_Bakeout"),
... )
Expand source code
class DataHandler(DataLoader, AnalysisLogic):
    """
    Handles automated data searching and extraction into dataclasses.

    Args:
        data_folder: Path to the data folder.
        figure_folder: Path to the figure folder.
        measurement_folder: Path to the measurement folder, relative to
            `data_folder` (default: no sub-folder).
        copy_measurement_folder_structure: If True (default), figures are saved in
            `figure_folder / measurement_folder`; if False, directly in
            `figure_folder`.

    Raises:
        OSError: If `data_folder / measurement_folder` does not exist.

    Examples:
        Create an instance of the DataHandler class:

        >>> dh = DataHandler(
        ...     data_folder=Path("C:/Data"),
        ...     figure_folder=Path("C:/QudiHiraAnalysis"),
        ...     measurement_folder=Path("20230101_Bakeout"),
        ... )
    """

    def __init__(
            self,
            data_folder: Path,
            figure_folder: Path,
            measurement_folder: Path = Path(),
            copy_measurement_folder_structure: bool = True,
    ):
        # The logger must exist before the folder helpers below, which log
        self.log = logging.getLogger(__name__)

        self.data_folder_path = self.__get_data_folder_path(data_folder, measurement_folder)

        # Either mirror the measurement sub-folder below the figure folder, or
        # write figures straight into the figure folder
        figure_subfolder = measurement_folder if copy_measurement_folder_structure else Path()
        self.figure_folder_path = self.__get_figure_folder_path(figure_folder, figure_subfolder)

        super().__init__(base_read_path=self.data_folder_path,
                         base_write_path=self.figure_folder_path)
        # Format of the 16-character timestamp prefix of qudi file names,
        # e.g. "20230101-0420-00"
        self.timestamp_format_str = "%Y%m%d-%H%M-%S"

    def __get_data_folder_path(self, data_folder: Path, folder_name: Path) -> Path:
        """
        Join the data folder with `folder_name` and check that the result exists.

        Unlike the figure folder, the data folder is never created.

        Raises:
            OSError: If the joined path does not exist.
        """
        path = data_folder / folder_name
        if not path.exists():
            raise OSError("Data folder path does not exist.")
        self.log.info(f"Data folder path is {path}")
        return path

    def __get_figure_folder_path(self, figure_folder: Path, folder_name: Path) -> Path:
        """
        Join the figure folder with `folder_name` and create it if it is missing.

        Any missing parent directories are created as well.
        """
        path = figure_folder / folder_name
        if not path.exists():
            # parents=True: the figure root and/or nested measurement folders may
            # not exist yet
            path.mkdir(parents=True, exist_ok=True)
            self.log.info(f"Creating new output folder path {path}")
        else:
            self.log.info(f"Figure folder path is {path}")
        return path

    def __tree(self, dir_path: Path, prefix: str = ''):
        """
        Recursively yield the lines of a visual tree of `dir_path`.

        Entries are sorted so that the output is deterministic. Each line is
        prefixed with the branch characters of its parent levels.
        """
        # Prefix components. All are 4 characters wide so that nested levels align.
        space = '    '
        branch = '│   '
        # Pointers
        tee = '├── '
        last = '└── '

        contents = sorted(dir_path.iterdir())
        # Every entry gets a ├── pointer except the final one, which gets └──
        pointers = [tee] * (len(contents) - 1) + [last]
        for pointer, path in zip(pointers, contents):
            yield prefix + pointer + path.name
            if path.is_dir():
                # Below the last entry no vertical bar is needed, only padding
                extension = branch if pointer == tee else space
                yield from self.__tree(path, prefix=prefix + extension)

    def __print_or_return_tree(self, folder: Path, print_tree: bool) -> str | None:
        """
        Print the tree of `folder` and return None, or return the tree as a string.
        """
        if print_tree:
            for line in self.__tree(folder):
                print(line)
            return None
        return "".join(f"{line}\n" for line in self.__tree(folder))

    def data_folder_tree(self, print_tree: bool = True) -> str | None:
        """
        Print or return a string tree of the data folder.

        Args:
            print_tree: Print the tree or return it as a string (default: True).

        Returns:
            str: The tree as a string if `print_tree` is False, otherwise None.
        """
        return self.__print_or_return_tree(self.data_folder_path, print_tree=print_tree)

    def figure_folder_tree(self, print_tree: bool = True) -> str | None:
        """
        Print or return a string tree of the figure folder.

        Args:
            print_tree: Print the tree or return it as a string (default: True).

        Returns:
            str: The tree as a string if `print_tree` is False, otherwise None.
        """
        return self.__print_or_return_tree(self.figure_folder_path, print_tree=print_tree)

    def _get_measurement_filepaths(
            self, measurement: str, extension: str, exclude_str: str | None = None
    ) -> list[Path]:
        """
        List all measurement files for a single measurement type, regardless of
        date, below the data folder.

        Args:
            measurement: Case-insensitive substring that must appear in the path.
            extension: Required file suffix (e.g. ".dat"). A falsy value accepts
                any suffix.
            exclude_str: Case-sensitive substring. Paths containing it are skipped.

        Returns:
            Matching file paths.
        """
        needle = measurement.lower()
        filepaths: list[Path] = []
        for path in self.data_folder_path.rglob("*"):
            if not path.is_file() or needle not in str(path).lower():
                continue
            if exclude_str is not None and exclude_str in str(path):
                continue
            if extension and path.suffix != extension:
                continue
            filepaths.append(path)
        return filepaths

    def __load_qudi_measurements_into_dataclass(
            self, measurement_str: str, pulsed: bool, extension: str
    ) -> dict[str, MeasurementDataclass]:
        """
        Lazy-load qudi measurements matching `measurement_str`.

        Results are keyed by the 16-character timestamp prefix of the file names.
        For pulsed measurements, the laser pulses, pulsed measurement and raw
        timetrace files that share one timestamp are combined into a single
        `MeasurementDataclass`.

        Raises:
            OSError: If a pulsed timestamp lacks any of the three required files.
        """
        if pulsed:
            # Group candidate files by timestamp prefix in one pass, instead of
            # rescanning every file for every timestamp
            filepaths_by_ts: dict[str, list[Path]] = {}
            for filepath in self._get_measurement_filepaths(
                    measurement=measurement_str, extension=extension,
                    exclude_str="image_1.dat"):
                filepaths_by_ts.setdefault(filepath.name[:16], []).append(filepath)

            pulsed_measurement_data: dict[str, MeasurementDataclass] = {}
            for ts, filepaths in filepaths_by_ts.items():
                pm, lp, rt = None, None, None
                for filepath in filepaths:
                    filename = filepath.name
                    if filename.endswith("laser_pulses.dat"):
                        lp = LaserPulses(filepath=filepath, loaders=self.trace_qudi_loader)
                    elif filename.endswith("pulsed_measurement.dat"):
                        pm = PulsedMeasurement(filepath=filepath,
                                               loaders=self.default_qudi_loader)
                    elif filename.endswith("raw_timetrace.dat"):
                        rt = RawTimetrace(filepath=filepath, loaders=self.trace_qudi_loader)
                    if lp and pm and rt:
                        break
                if not (lp and pm and rt):
                    found = [filepath.name for filepath in filepaths]
                    raise OSError(
                        f"'{ts}' is an invalid pulsed measurement: expected "
                        f"laser_pulses.dat, pulsed_measurement.dat and raw_timetrace.dat "
                        f"files, found {found}.")
                pulsed_measurement_data[ts] = MeasurementDataclass(
                    timestamp=datetime.datetime.strptime(ts, self.timestamp_format_str),
                    pulsed=PulsedMeasurementDataclass(
                        measurement=pm,
                        laser_pulses=lp,
                        timetrace=rt
                    )
                )
            return pulsed_measurement_data

        if measurement_str.lower() == "confocal":
            loaders = self.confocal_qudi_loader
            exclude_str = "xy_data.dat"
        elif measurement_str.lower() == "pixelscanner":
            loaders = self.pixelscanner_qudi_loader
            exclude_str = None
        else:
            loaders = self.default_qudi_loader
            exclude_str = None

        measurement_data: dict[str, MeasurementDataclass] = {}
        for filepath in self._get_measurement_filepaths(measurement_str, extension, exclude_str):
            ts = filepath.name[:16]
            measurement_data[ts] = MeasurementDataclass(
                filepath=filepath,
                timestamp=datetime.datetime.strptime(ts, self.timestamp_format_str),
                _loaders=loaders
            )
        return measurement_data

    def __load_standard_measurements_into_dataclass(
            self, measurement_str: str, extension: str
    ) -> dict[str, MeasurementDataclass]:
        """
        Lazy-load non-qudi measurements matching `measurement_str`.

        The loader is inferred from `measurement_str` or `extension`. For
        "temperature-monitoring" (.xls) and "pressure-monitoring" (.txt) the
        extension is fixed. Keys are built from each file's modification time.
        """
        measurement_type = measurement_str.lower()
        # Try and infer measurement type
        if measurement_type == "temperature-monitoring":
            loaders = self.temperature_loader
            extension = ".xls"
        elif measurement_type == "pressure-monitoring":
            loaders = self.pressure_loader
            extension = ".txt"
        elif measurement_type == "frq-sweep":
            loaders = self.nanonis_loader
        elif extension == ".sxm":
            loaders = self.nanonis_spm_loader
        elif extension == ".pys":
            loaders = self.pys_loader
        elif extension == ".001":
            loaders = self.bruker_spm_loader
        else:
            loaders = self.default_qudi_loader

        filepaths = self._get_measurement_filepaths(measurement_str, extension, None)
        if filepaths:
            # Warn once per call rather than once per file
            self.log.warning(
                "Extracting timestamp from file modified time, may not be accurate.")

        measurement_list: dict[str, MeasurementDataclass] = {}
        for filepath in filepaths:
            timestamp = datetime.datetime.fromtimestamp(filepath.stat().st_mtime)
            ts = datetime.datetime.strftime(timestamp, self.timestamp_format_str)
            if ts in measurement_list:
                # Keys only have one-second resolution, so do not drop files silently
                self.log.warning(
                    f"Duplicate timestamp {ts}: '{filepath}' overwrites a previously "
                    f"loaded file.")
            measurement_list[ts] = MeasurementDataclass(
                filepath=filepath,
                timestamp=timestamp,
                _loaders=loaders
            )
        return measurement_list

    def load_measurements(
            self,
            measurement_str: str,
            qudi: bool = True,
            pulsed: bool = False,
            extension: str = ".dat"
    ) -> dict[str, MeasurementDataclass]:
        """
        Lazy-load all measurements of a given type into a dictionary of dataclasses.

        Args:
            measurement_str: The name of the measurement type to load e.g. t1, t2,
                confocal etc. Recursively searches through the path defined by
                data_folder and measurement_folder. Case-insensitive
                (i.e. "odmr" == "ODMR" == "Odmr").
            qudi: Whether the measurement is a qudi measurement (default: True).
            pulsed: Whether the measurement is a pulsed measurement (default: False).
            extension: The file extension of the measurement files (default: .dat).
                Overridden for "temperature-monitoring" (.xls) and
                "pressure-monitoring" (.txt) when `qudi` is False.

        Returns:
            dict: A dictionary where keys are the measurement timestamps and values
                are dataclasses containing the measurement data.

        Examples:
            `dh` is an instance of the `DataHandler` class.

            Load all pulsed ODMR measurements:

            >>> dh.load_measurements(measurement_str="ODMR", pulsed=True)

            Load all confocal data:

            >>> dh.load_measurements(measurement_str="Confocal")

            Load all temperature monitoring data:

            >>> dh.load_measurements(measurement_str="temperature-monitoring", qudi=False)

            Load all pressure monitoring data:

            >>> dh.load_measurements(measurement_str="pressure-monitoring", qudi=False)
        """
        measurement_str = measurement_str.lower()
        if qudi:
            return self.__load_qudi_measurements_into_dataclass(
                measurement_str, pulsed=pulsed, extension=extension
            )
        return self.__load_standard_measurements_into_dataclass(
            measurement_str, extension=extension
        )
Ancestors
- DataLoader
- IOHandler
- AnalysisLogic
- qudi_hira_analysis._qudi_fit_logic.FitLogic
Methods
def data_folder_tree(self, print_tree: bool = True) ‑> str | None
-
Print or return a string tree of the data folder.
Args
print_tree
- Print the tree or return it as a string (default: True).
Returns
str
- The tree as a string if `print_tree` is False.
Expand source code
def data_folder_tree(self, print_tree: bool = True) -> str | None: """ Print or return a string tree of the data folder. Args: print_tree: Print the tree or return it as a string (default: True). Returns: str: The tree as a string if `print_tree is False. """ return self.__print_or_return_tree(self.data_folder_path, print_tree=print_tree)
def figure_folder_tree(self, print_tree: bool = True) ‑> str | None
-
Print or return a string tree of the figure folder.
Args
print_tree
- Print the tree or return it as a string (default: True).
Returns
str
- The tree as a string if `print_tree` is False.
Expand source code
def figure_folder_tree(self, print_tree: bool = True) -> str | None: """ Print or return a string tree of the figure folder. Args: print_tree: Print the tree or return it as a string (default: True). Returns: str: The tree as a string if `print_tree is False. """ return self.__print_or_return_tree(self.figure_folder_path, print_tree=print_tree)
def load_measurements(self, measurement_str: str, qudi: bool = True, pulsed: bool = False, extension: str = '.dat') ‑> dict[str, MeasurementDataclass]
-
Lazy-load all measurements of a given type into a dictionary of dataclasses.
Args
measurement_str
- The name of the measurement type to load e.g. t1, t2, confocal etc. Recursively searches through the path defined by data_folder and measurement_folder. Case-insensitive (i.e. "odmr" == "ODMR" == "Odmr").
qudi
- Whether the measurement is a qudi measurement (default: True).
pulsed
- Whether the measurement is a pulsed measurement (default: False).
extension
- The file extension of the measurement files (default: .dat).
Returns
dict
- A dictionary where keys are the measurement timestamps and values are dataclasses containing the measurement data.
Examples
`dh` is an instance of the `DataHandler` class.
Load all ODMR measurements:
>>> dh.load_measurements(measurement_str="ODMR", pulsed=True)
Load all confocal data:
>>> dh.load_measurements(measurement_str="Confocal")
Load all temperature monitoring data:
>>> dh.load_measurements(measurement_str="temperature-monitoring", qudi=False)
Load all pressure monitoring data:
>>> dh.load_measurements(measurement_str="pressure-monitoring", qudi=False)
Expand source code
def load_measurements(
        self,
        measurement_str: str,
        qudi: bool = True,
        pulsed: bool = False,
        extension: str = ".dat"
) -> dict[str, MeasurementDataclass]:
    """
    Lazy-load all measurements of a given type into a dictionary of dataclasses.

    Args:
        measurement_str: The name of the measurement type to load e.g. t1, t2,
            confocal etc. Recursively searches through the path defined by
            data_folder and measurement_folder. Case-insensitive
            (i.e. "odmr" == "ODMR" == "Odmr").
        qudi: Whether the measurement is a qudi measurement (default: True).
        pulsed: Whether the measurement is a pulsed measurement (default: False).
        extension: The file extension of the measurement files (default: .dat).
            Now also honoured for qudi measurements.

    Returns:
        dict: A dictionary where keys are the measurement timestamps and values
            are dataclasses containing the measurement data.

    Examples:
        `dh` is an instance of the `DataHandler` class.

        Load all pulsed ODMR measurements:

        >>> dh.load_measurements(measurement_str="ODMR", pulsed=True)

        Load all confocal data:

        >>> dh.load_measurements(measurement_str="Confocal")

        Load all temperature monitoring data:

        >>> dh.load_measurements(measurement_str="temperature-monitoring", qudi=False)

        Load all pressure monitoring data:

        >>> dh.load_measurements(measurement_str="pressure-monitoring", qudi=False)
    """
    measurement_str = measurement_str.lower()
    if qudi:
        # Pass the caller's extension through; it was previously hard-coded to ".dat"
        return self.__load_qudi_measurements_into_dataclass(
            measurement_str, pulsed=pulsed, extension=extension
        )
    return self.__load_standard_measurements_into_dataclass(
        measurement_str, extension=extension
    )
Inherited members
DataLoader:
read_bruker_spm_data
read_confocal_into_dataframe
read_csv
read_excel
read_into_dataframe
read_into_ndarray
read_into_ndarray_transposed
read_lakeshore_data
read_nanonis_data
read_nanonis_parameters
read_nanonis_spm_data
read_oceanoptics_data
read_pfeiffer_data
read_pixelscanner_data
read_pkl
read_pys
read_qudi_parameters
save_df
save_figures
save_pkl
save_pys
AnalysisLogic
:
class FitMethodsAndEstimators
-
Class for storing fit methods and estimators. Fit methods are stored as tuples of (method, estimator) where method is the name of the fit method and estimator is the name of the estimator.
The fit functions available are:
| Dimension | Fit |
|-----------|-------------------------------|
| 1D | antibunching |
| | decayexponential |
| | biexponential |
| | decayexponentialstretched |
| | gaussian |
| | gaussiandouble |
| | gaussianlinearoffset |
| | hyperbolicsaturation |
| | linear |
| | lorentzian |
| | lorentziandouble |
| | lorentziantriple |
| | sine |
| | sinedouble |
| | sinedoublewithexpdecay |
| | sinedoublewithtwoexpdecay |
| | sineexponentialdecay |
| | sinestretchedexponentialdecay |
| | sinetriple |
| | sinetriplewithexpdecay |
| | sinetriplewiththreeexpdecay |
| 2D | twoDgaussian |

Expand source code
class FitMethodsAndEstimators:
    """
    Class for storing fit methods and estimators.

    Fit methods are stored as tuples of (method, estimator), where method is the
    name of the fit method and estimator is the name of the estimator.

    The fit functions available are:

    | Dimension | Fit                           |
    |-----------|-------------------------------|
    | 1D        | antibunching                  |
    |           | decayexponential              |
    |           | biexponential                 |
    |           | decayexponentialstretched     |
    |           | gaussian                      |
    |           | gaussiandouble                |
    |           | gaussianlinearoffset          |
    |           | hyperbolicsaturation          |
    |           | linear                        |
    |           | lorentzian                    |
    |           | lorentziandouble              |
    |           | lorentziantriple              |
    |           | sine                          |
    |           | sinedouble                    |
    |           | sinedoublewithexpdecay        |
    |           | sinedoublewithtwoexpdecay     |
    |           | sineexponentialdecay          |
    |           | sinestretchedexponentialdecay |
    |           | sinetriple                    |
    |           | sinetriplewithexpdecay        |
    |           | sinetriplewiththreeexpdecay   |
    | 2D        | twoDgaussian                  |
    """
    # Fit methods with corresponding estimators
    antibunching: tuple = ("antibunching", "dip")
    hyperbolicsaturation: tuple = ("hyperbolicsaturation", "generic")
    lorentzian: tuple = ("lorentzian", "dip")
    lorentziandouble: tuple = ("lorentziandouble", "dip")
    sineexponentialdecay: tuple = ("sineexponentialdecay", "generic")
    decayexponential: tuple = ("decayexponential", "generic")
    gaussian: tuple = ("gaussian", "dip")
    gaussiandouble: tuple = ("gaussiandouble", "dip")
    gaussianlinearoffset: tuple = ("gaussianlinearoffset", "dip")
    lorentziantriple: tuple = ("lorentziantriple", "dip")
    biexponential: tuple = ("biexponential", "generic")
    decayexponentialstretched: tuple = ("decayexponentialstretched", "generic")
    linear: tuple = ("linear", "generic")
    sine: tuple = ("sine", "generic")
    sinedouble: tuple = ("sinedouble", "generic")
    sinedoublewithexpdecay: tuple = ("sinedoublewithexpdecay", "generic")
    sinedoublewithtwoexpdecay: tuple = ("sinedoublewithtwoexpdecay", "generic")
    sinestretchedexponentialdecay: tuple = ("sinestretchedexponentialdecay", "generic")
    sinetriple: tuple = ("sinetriple", "generic")
    sinetriplewithexpdecay: tuple = ("sinetriplewithexpdecay", "generic")
    sinetriplewiththreeexpdecay: tuple = ("sinetriplewiththreeexpdecay", "generic")
    twoDgaussian: tuple = ("twoDgaussian", "generic")  # noqa: N815
Class variables
var antibunching : tuple
var biexponential : tuple
var decayexponential : tuple
var decayexponentialstretched : tuple
var gaussian : tuple
var gaussiandouble : tuple
var gaussianlinearoffset : tuple
var hyperbolicsaturation : tuple
var linear : tuple
var lorentzian : tuple
var lorentziandouble : tuple
var lorentziantriple : tuple
var sine : tuple
var sinedouble : tuple
var sinedoublewithexpdecay : tuple
var sinedoublewithtwoexpdecay : tuple
var sineexponentialdecay : tuple
var sinestretchedexponentialdecay : tuple
var sinetriple : tuple
var sinetriplewithexpdecay : tuple
var sinetriplewiththreeexpdecay : tuple
var twoDgaussian : tuple
class IOHandler (base_read_path: Optional[pathlib.Path] = None, base_write_path: Optional[pathlib.Path] = None)
-
Handle all read and write operations.
Expand source code
class IOHandler:
    """
    Handle all read and write operations.

    Methods decorated with `_add_base_read_path` resolve `filepath` relative to
    `base_read_path`. Methods decorated with `_add_base_write_path` resolve it
    relative to `base_write_path`. In both cases the base path is only added
    when it is set.
    """

    def __init__(
            self,
            base_read_path: Optional[Path] = None,
            base_write_path: Optional[Path] = None
    ):
        super().__init__()
        self.base_read_path = base_read_path
        self.base_write_path = base_write_path

    @staticmethod
    def _add_base_read_path(func: Callable) -> Callable:
        """
        Decorator to prepend `base_read_path` to the filepath if it is not None.

        Args:
            func: Function to be decorated

        Returns:
            Decorated function
        """

        @wraps(func)
        def wrapper(self, filepath: Path, *args, **kwargs):
            filepath = Path(filepath)
            if self.base_read_path:
                filepath = self.base_read_path / filepath
            return func(self, filepath, *args, **kwargs)

        return wrapper

    @staticmethod
    def _add_base_write_path(func: Callable) -> Callable:
        """
        Decorator to prepend `base_write_path` to the filepath if it is not None.

        Missing parent folders of the final path are created.

        Args:
            func: Function to be decorated

        Returns:
            Decorated function
        """

        @wraps(func)
        def wrapper(self, filepath: Path, *args, **kwargs):
            filepath = Path(filepath)
            if self.base_write_path:
                filepath = self.base_write_path / filepath
            # parents=True so that nested sub-folders in `filepath` are created too
            filepath.parent.mkdir(parents=True, exist_ok=True)
            return func(self, filepath, *args, **kwargs)

        return wrapper

    @staticmethod
    def _check_extension(ext: str) -> Callable:
        """
        Decorator to check that the extension of the filepath is correct.

        A filepath without an extension gets `ext` appended.

        Args:
            ext: Extension to check for

        Returns:
            Decorated function

        Raises:
            OSError: If the filepath has a different, non-empty extension.
        """

        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(self, filepath: Path, *args, **kwargs) -> Callable:
                filepath = Path(filepath)
                if filepath.suffix == ext:
                    return func(self, filepath, *args, **kwargs)
                if filepath.suffix == "":
                    return func(self, filepath.with_suffix(ext), *args, **kwargs)
                raise OSError(
                    f"Invalid extension '{filepath.suffix}' in '{filepath}', "
                    f"extension should be '{ext}'")

            return wrapper

        return decorator

    @_add_base_read_path
    @_check_extension(".dat")
    def read_qudi_parameters(self, filepath: Path) -> dict:
        """Extract parameters from the header of a qudi dat file.

        Header lines are read until the '#=====' separator. Lines whose value
        cannot be parsed are skipped.

        Args:
            filepath: Path to the qudi .dat file

        Returns:
            Dictionary of parameters
        """
        params = {}
        with open(filepath) as file:
            for line in file:
                if line == '#=====\n':
                    break
                # noinspection PyBroadException
                try:
                    # Remove # from beginning of lines
                    line = line[1:]
                    if line.count(":") == 1:
                        # Add params to dictionary
                        label, value = line.split(":")
                        if value != "\n":
                            params[label] = ast.literal_eval(inspect.cleandoc(value))
                    elif line.count(":") == 3:
                        # Handle files with timestamps in them
                        label = line.split(":")[0]
                        timestamp_str = "".join(line.split(":")[1:]).strip()
                        datetime_str = datetime.datetime.strptime(
                            timestamp_str, "%d.%m.%Y %Hh%Mmin%Ss"
                        ).replace(tzinfo=datetime.timezone.utc)
                        params[label] = datetime_str
                except Exception:
                    # Deliberate best-effort parsing: skip header lines whose value
                    # is not a Python literal or a known timestamp format
                    continue
        return params

    @_add_base_read_path
    @_check_extension(".dat")
    def read_into_dataframe(self, filepath: Path) -> pd.DataFrame:
        """Read a qudi data file into a pandas DataFrame for analysis.

        Args:
            filepath: Path to the qudi data file

        Returns:
            DataFrame containing the data from the qudi data file
        """
        with open(filepath) as handle:
            # The last line of the leading '#' block holds the column names
            *_comments, names = itertools.takewhile(lambda line: line.startswith('#'), handle)
            names = names[1:].strip().split("\t")
        return pd.read_csv(filepath, names=names, comment="#", sep="\t")

    @_add_base_read_path
    def read_csv(self, filepath: Path, **kwargs) -> pd.DataFrame:
        """
        Read a csv file into a pandas DataFrame.
        """
        return pd.read_csv(filepath, **kwargs)

    @_add_base_read_path
    def read_excel(self, filepath: Path, **kwargs) -> pd.DataFrame:
        """
        Read an Excel file into a pandas DataFrame.
        """
        return pd.read_excel(filepath, **kwargs)

    # No `_add_base_read_path` here: `read_qudi_parameters` and `read_into_ndarray`
    # each prepend the base read path themselves. Decorating this method as well
    # would prepend it twice, which breaks relative base paths.
    @_check_extension(".dat")
    def read_confocal_into_dataframe(self, filepath: Path) -> pd.DataFrame:
        """
        Read a qudi confocal data file into a pandas DataFrame for analysis.
        """
        confocal_params = self.read_qudi_parameters(filepath)
        data = self.read_into_ndarray(filepath, delimiter="\t")
        # Use the confocal parameters to generate the index & columns for the DataFrame
        index = np.linspace(
            confocal_params['X image min (m)'],
            confocal_params['X image max (m)'],
            data.shape[0]
        )
        columns = np.linspace(
            confocal_params['Y image min'],
            confocal_params['Y image max'],
            data.shape[1]
        )
        df = pd.DataFrame(data, index=index, columns=columns)
        # Sort the index to get origin (0, 0) in the lower left corner of the DataFrame
        df.sort_index(axis=0, ascending=False, inplace=True)
        return df

    @_add_base_read_path
    def read_into_ndarray(self, filepath: Path, **kwargs) -> np.ndarray:
        """
        Read a file into a numpy ndarray.
        """
        return np.genfromtxt(filepath, **kwargs)

    @_add_base_read_path
    def read_into_ndarray_transposed(self, filepath: Path, **kwargs) -> np.ndarray:
        """
        Read a file into a transposed numpy ndarray.
        """
        return np.genfromtxt(filepath, **kwargs).T

    @_add_base_read_path
    @_check_extension(".pys")
    def read_pys(self, filepath: Path) -> dict:
        """
        Read raw .pys data files into a dictionary.

        Note: allow_pickle=True can execute arbitrary code. Only load trusted files.
        """
        byte_dict = np.load(str(filepath), encoding="bytes", allow_pickle=True)
        # Convert byte string keys to normal strings
        return {key.decode('utf8'): byte_dict.get(key) for key in byte_dict}

    @_add_base_read_path
    @_check_extension(".pkl")
    def read_pkl(self, filepath: Path) -> dict:
        """
        Read pickle files into a dictionary.

        Note: unpickling can execute arbitrary code. Only load trusted files.
        """
        with open(filepath, 'rb') as f:
            file = pickle.load(f)
        return file

    @_add_base_read_path
    @_check_extension(".dat")
    def read_nanonis_data(self, filepath: Path) -> pd.DataFrame:
        """Read data from a Nanonis .dat file.

        Args:
            filepath: Path to the Nanonis .dat file.

        Returns:
            DataFrame of data.
        """
        skip_rows = 0
        with open(filepath) as dat_file:
            for num, line in enumerate(dat_file, 1):
                if "[DATA]" in line:
                    # Find number of rows to skip when extracting data
                    skip_rows = num
                    break
                if "#=====" in line:
                    skip_rows = num
                    break
        df = pd.read_table(filepath, sep="\t", skiprows=skip_rows)
        return df

    @_add_base_read_path
    @_check_extension(".dat")
    def read_nanonis_parameters(self, filepath: Path) -> dict:
        """Read parameters from a Nanonis .dat file.

        Args:
            filepath: Path to the Nanonis .dat file.

        Returns:
            Dictionary of parameters.
        """
        parameters = {}
        with open(filepath) as dat_file:
            for line in dat_file:
                if line == "\n":
                    # Break when reaching empty line
                    break
                elif "User" in line or line.split("\t")[0] == "":
                    # Cleanup excess parameters and skip empty lines
                    pass
                else:
                    label, value, _ = line.split("\t")
                    # Keep non-numeric values as strings
                    with contextlib.suppress(ValueError):
                        value = float(value)
                    if "Oscillation Control>" in label:
                        label = label.replace("Oscillation Control>", "")
                    parameters[label] = value
        return parameters

    @_add_base_read_path
    @_check_extension(".sxm")
    def read_nanonis_spm_data(self, filepath: Path) -> pySPM.SXM:
        """Read a Nanonis .sxm data file.

        Args:
            filepath: Path to the .sxm file.

        Returns:
            pySPM.SXM object containing the data.
        """
        return pySPM.SXM(filepath)

    @_add_base_read_path
    @_check_extension(".001")
    def read_bruker_spm_data(self, filepath: Path) -> pySPM.Bruker:
        """Read a Bruker SPM data file.

        Args:
            filepath: Path to the .001 file.

        Returns:
            pySPM.Bruker object containing the data.
        """
        return pySPM.Bruker(filepath)

    @_add_base_read_path
    @_check_extension(".txt")
    def read_pfeiffer_data(self, filepath: Path) -> pd.DataFrame:
        """Read data stored by Pfeiffer vacuum monitoring software.

        Args:
            filepath: Path to the text file.

        Returns:
            DataFrame containing the data.
        """
        # Extract rows including the header
        df = pd.read_csv(filepath, sep="\t", skiprows=[0, 2, 3, 4])
        # Combine data and time columns together
        df["Date"] = df["Date"] + " " + df["Time"]
        df = df.drop("Time", axis=1)
        # Infer datetime format and convert to datetime objects
        df["Date"] = pd.to_datetime(df["Date"], infer_datetime_format=True)
        # Set datetime as index
        df = df.set_index("Date", drop=True)
        return df

    @_add_base_read_path
    @_check_extension(".xls")
    def read_lakeshore_data(self, filepath: Path) -> pd.DataFrame:
        """Read data stored by Lakeshore temperature monitor software.

        Args:
            filepath: Path to the Excel file.

        Returns:
            DataFrame containing the data.
        """
        # Extract only the origin timestamp
        origin = pd.read_excel(
            filepath, skiprows=1, nrows=1, usecols=[1], header=None
        )[1][0]
        # Remove any tzinfo to prevent future exceptions in pandas
        origin = origin.replace("CET", "")
        # Parse datetime object from timestamp
        origin = pd.to_datetime(origin)
        # Create DataFrame and drop empty cols
        df = pd.read_excel(filepath, skiprows=3)
        df = df.dropna(axis=1, how="all")
        # Add datetimes to DataFrame
        df["Datetime"] = pd.to_datetime(df["Time"], unit="ms", origin=origin)
        df = df.drop("Time", axis=1)
        # Set datetime as index
        df = df.set_index("Datetime", drop=True)
        return df

    @_add_base_read_path
    @_check_extension(".txt")
    def read_oceanoptics_data(self, filepath: Path) -> pd.DataFrame:
        """Read spectrometer data from OceanOptics spectrometer.

        Args:
            filepath: Path to the data file.

        Returns:
            DataFrame containing the wavelength and intensity data.
        """
        df = pd.read_csv(filepath, sep="\t", skiprows=14, names=["wavelength", "intensity"])
        return df

    @staticmethod
    def __get_forward_backward_counts(count_rates, num_pixels):
        """Split interleaved forward/backward scan lines into two 2D arrays."""
        split_array = np.split(count_rates, 2 * num_pixels)
        # Extract forward scan array as every second element
        forward_counts = np.stack(split_array[::2])
        # Extract backward scan array as every shifted second element
        # Flip scan so that backward and forward scans represent the same data
        backward_counts = np.flip(np.stack(split_array[1::2]), axis=1)
        return forward_counts, backward_counts

    def read_pixelscanner_data(self, filepath: Path) -> tuple[pySPM.SPM_image, pySPM.SPM_image]:
        """
        Read data from a PixelScanner measurement.

        Args:
            filepath: Path to the data file.

        Returns:
            Forward and backward scan data.

        Raises:
            ValueError: If the number of rows is not twice a perfect square.
        """
        df = self.read_into_dataframe(filepath)
        num_pixels = int(np.sqrt(len(df) // 2))
        if num_pixels ** 2 != len(df) // 2:
            raise ValueError("Number of pixels does not match data length.")
        try:
            fwd, bwd = self.__get_forward_backward_counts(df["count_rates"], num_pixels)
        except KeyError:
            try:
                fwd, bwd = self.__get_forward_backward_counts(df["Count Rates (cps)"], num_pixels)
            except KeyError:
                # Support old data format
                fwd = df["forward (cps)"].to_numpy().reshape(num_pixels, num_pixels)
                bwd = df["backward (cps)"].to_numpy().reshape(num_pixels, num_pixels)
        fwd = pySPM.SPM_image(fwd, channel="Forward", _type="NV-PL")
        bwd = pySPM.SPM_image(bwd, channel="Backward", _type="NV-PL")
        return fwd, bwd

    @_add_base_write_path
    @_check_extension(".pkl")
    def save_pkl(self, filepath: Path, obj: object):
        """Saves pickle files.

        Args:
            filepath: Path to the data file.
            obj: Object to be saved.
        """
        with open(filepath, 'wb') as f:
            pickle.dump(obj, f)

    @_add_base_write_path
    @_check_extension(".pys")
    def save_pys(self, filepath: Path, dictionary: dict):
        """Saves .pys files.

        Args:
            filepath: Path to the data file.
            dictionary: Dictionary to be saved.
        """
        with open(filepath, 'wb') as f:
            pickle.dump(dictionary, f, 1)

    @_add_base_write_path
    @_check_extension(".csv")
    def save_df(self, filepath: Path, df: pd.DataFrame):
        """
        Save a DataFrame as a tab-separated .csv file.
        """
        df.to_csv(filepath, sep='\t', encoding='utf-8')

    @_add_base_write_path
    def save_figures(self, filepath: Path, fig: plt.Figure, **kwargs):
        """Saves figures from matplotlib plot data.

        By default, saves as jpg, pdf, svg and png.

        Args:
            fig: Matplotlib figure to save.
            filepath: Name of figure to save.
            only_jpg: If True, only save as jpg (default: False).
            only_pdf: If True, only save as pdf (default: False). Ignored when
                only_jpg is True.
            **kwargs: Keyword arguments passed to matplotlib.pyplot.savefig().
                `dpi` defaults to 200.
        """
        # Pop both flags so that neither leaks into savefig(). A False flag now
        # falls back to the default formats instead of crashing.
        only_jpg = kwargs.pop("only_jpg", False)
        only_pdf = kwargs.pop("only_pdf", False)
        if only_jpg:
            extensions = [".jpg"]
        elif only_pdf:
            extensions = [".pdf"]
        else:
            extensions = [".jpg", ".pdf", ".svg", ".png"]

        # Let callers override the default resolution without a duplicate-kwarg error
        kwargs.setdefault("dpi", 200)
        for ext in extensions:
            fig.savefig(filepath.with_suffix(ext), **kwargs)
Subclasses
Methods
def read_bruker_spm_data(self, filepath: pathlib.Path) ‑> pySPM.Bruker.Bruker
-
Read a Bruker SPM data file.
Args
filepath
- Path to the .001 file.
Returns
pySPM.Bruker object containing the data.
Expand source code
@_add_base_read_path
@_check_extension(".001")
def read_bruker_spm_data(self, filepath: Path) -> pySPM.Bruker:
    """Load a Bruker SPM data file with pySPM.

    Args:
        filepath: Location of the `.001` file.

    Returns:
        The pySPM.Bruker object wrapping the file contents.
    """
    bruker_file = pySPM.Bruker(filepath)
    return bruker_file
def read_confocal_into_dataframe(self, filepath: pathlib.Path) ‑> pandas.core.frame.DataFrame
-
Read a qudi confocal data file into a pandas DataFrame for analysis.
Expand source code
# No `_add_base_read_path` here: `read_qudi_parameters` and `read_into_ndarray`
# each prepend the base read path themselves. Decorating this method as well
# would prepend it twice, which breaks relative base paths.
@_check_extension(".dat")
def read_confocal_into_dataframe(self, filepath: Path) -> pd.DataFrame:
    """
    Read a qudi confocal data file into a pandas DataFrame for analysis.

    Args:
        filepath: Path to the qudi confocal .dat file, relative to the base read
            path if one is set.

    Returns:
        DataFrame of the image. The index spans the X image range and the
        columns span the Y image range from the file header. The index is
        sorted in descending order.
    """
    confocal_params = self.read_qudi_parameters(filepath)
    data = self.read_into_ndarray(filepath, delimiter="\t")
    # Use the confocal parameters to generate the index & columns for the DataFrame
    index = np.linspace(
        confocal_params['X image min (m)'],
        confocal_params['X image max (m)'],
        data.shape[0]
    )
    columns = np.linspace(
        confocal_params['Y image min'],
        confocal_params['Y image max'],
        data.shape[1]
    )
    df = pd.DataFrame(data, index=index, columns=columns)
    # Sort the index to get origin (0, 0) in the lower left corner of the DataFrame
    df.sort_index(axis=0, ascending=False, inplace=True)
    return df
def read_csv(self, filepath: pathlib.Path, **kwargs) ‑> pandas.core.frame.DataFrame
-
Read a csv file into a pandas DataFrame.
Expand source code
@_add_base_read_path
def read_csv(self, filepath: Path, **kwargs) -> pd.DataFrame:
    """
    Load a csv file with pandas.

    Args:
        filepath: Path to the csv file.
        **kwargs: Forwarded unchanged to `pandas.read_csv`.

    Returns:
        The parsed DataFrame.
    """
    frame = pd.read_csv(filepath, **kwargs)
    return frame
def read_excel(self, filepath: pathlib.Path, **kwargs) ‑> pandas.core.frame.DataFrame
-
Read an Excel file into a pandas DataFrame.
Expand source code
@_add_base_read_path
def read_excel(self, filepath: Path, **kwargs) -> pd.DataFrame:
    """
    Read an Excel file into a pandas DataFrame.

    Args:
        filepath: Path to the Excel file.
        **kwargs: Keyword arguments passed to `pandas.read_excel`.

    Returns:
        DataFrame returned by `pandas.read_excel`.
    """
    return pd.read_excel(filepath, **kwargs)
def read_into_dataframe(self, filepath: pathlib.Path) ‑> pandas.core.frame.DataFrame
-
Read a qudi data file into a pandas DataFrame for analysis.
Args
filepath
- Path to the qudi data file
Returns
DataFrame containing the data from the qudi data file
Expand source code
@_add_base_read_path
@_check_extension(".dat")
def read_into_dataframe(self, filepath: Path) -> pd.DataFrame:
    """Load a qudi data file as a pandas DataFrame.

    The column names are taken from the last line of the leading block of
    '#'-prefixed header lines. All '#' lines are skipped when the data is parsed.

    Args:
        filepath: Location of the qudi data file.

    Returns:
        DataFrame holding the tab-separated data of the file.
    """
    def is_comment(text: str) -> bool:
        return text.startswith('#')

    with open(filepath) as header_file:
        *_, header_line = itertools.takewhile(is_comment, header_file)
    column_names = header_line[1:].strip().split("\t")
    return pd.read_csv(filepath, names=column_names, comment="#", sep="\t")
def read_into_ndarray(self, filepath: pathlib.Path, **kwargs) ‑> numpy.ndarray
-
Read a file into a numpy ndarray.
Expand source code
@_add_base_read_path
def read_into_ndarray(self, filepath: Path, **kwargs) -> np.ndarray:
    """
    Load a text file into a numpy array using `numpy.genfromtxt`.

    Args:
        filepath: Path to the text file.
        **kwargs: Forwarded unchanged to `numpy.genfromtxt`.

    Returns:
        The loaded array.
    """
    array = np.genfromtxt(filepath, **kwargs)
    return array
def read_into_ndarray_transposed(self, filepath: pathlib.Path, **kwargs) ‑> numpy.ndarray
-
Read a file into a transposed numpy ndarray.
Expand source code
@_add_base_read_path
def read_into_ndarray_transposed(self, filepath: Path, **kwargs) -> np.ndarray:
    """Load a text data file into a numpy ndarray and return its transpose.

    Args:
        filepath: Path to the data file.
        **kwargs: Forwarded unchanged to ``numpy.genfromtxt``.

    Returns:
        The transposed array.
    """
    loaded = np.genfromtxt(filepath, **kwargs)
    return np.transpose(loaded)
def read_lakeshore_data(self, filepath: pathlib.Path) ‑> pandas.core.frame.DataFrame
-
Read data stored by Lakeshore temperature monitor software.
Args
filepath
- Path to the Excel file.
Returns
DataFrame containing the data.
Expand source code
@_add_base_read_path
@_check_extension(".xls")
def read_lakeshore_data(self, filepath: Path) -> pd.DataFrame:
    """Read data stored by Lakeshore temperature monitor software.

    The acquisition start time is read from the second row, second column of
    the sheet. The "Time" column (milliseconds) is converted to absolute
    datetimes relative to that origin and used as the index.

    Args:
        filepath: Path to the Excel file.

    Returns:
        DataFrame containing the data, indexed by "Datetime".
    """
    # Extract only the origin timestamp
    origin = pd.read_excel(
        filepath, skiprows=1, nrows=1, usecols=[1], header=None
    )[1][0]
    # Remove timezone abbreviations (winter "CET" and summer "CEST") to
    # prevent exceptions in pandas when parsing the timestamp
    for tz_name in ("CEST", "CET"):
        origin = origin.replace(tz_name, "")
    # Parse datetime object from timestamp
    origin = pd.to_datetime(origin)

    # Create DataFrame and drop empty cols
    df = pd.read_excel(filepath, skiprows=3)
    df = df.dropna(axis=1, how="all")
    # Add datetimes to DataFrame
    df["Datetime"] = pd.to_datetime(df["Time"], unit="ms", origin=origin)
    df = df.drop("Time", axis=1)
    # Set datetime as index
    df = df.set_index("Datetime", drop=True)
    return df
def read_nanonis_data(self, filepath: pathlib.Path) ‑> pandas.core.frame.DataFrame
-
Read data from a Nanonis .dat file.
Args
filepath
- Path to the Nanonis .dat file.
Returns
DataFrame of data.
Expand source code
@_add_base_read_path
@_check_extension(".dat")
def read_nanonis_data(self, filepath: Path) -> pd.DataFrame:
    """Load the tabular data section of a Nanonis .dat file.

    Lines up to and including the first line containing "[DATA]" or
    "#=====" are skipped. If neither marker occurs, nothing is skipped.

    Args:
        filepath: Path to the Nanonis .dat file.

    Returns:
        DataFrame of data.
    """
    header_markers = ("[DATA]", "#=====")
    rows_to_skip = 0
    with open(filepath) as dat_file:
        for line_number, line in enumerate(dat_file, start=1):
            if any(marker in line for marker in header_markers):
                rows_to_skip = line_number
                break
    return pd.read_table(filepath, sep="\t", skiprows=rows_to_skip)
def read_nanonis_parameters(self, filepath: pathlib.Path) ‑> dict
-
Read parameters from a Nanonis .dat file.
Args
filepath
- Path to the Nanonis .dat file.
Returns
Dictionary of parameters.
Expand source code
@_add_base_read_path
@_check_extension(".dat")
def read_nanonis_parameters(self, filepath: Path) -> dict:
    """Collect the header parameters of a Nanonis .dat file into a dict.

    Reading stops at the first empty line. Lines mentioning "User" and lines
    with an empty first field are ignored. Every other line is expected to
    hold exactly three tab-separated fields (label, value, trailing field).
    Values are converted to float when possible, and the
    "Oscillation Control>" prefix is removed from labels.

    Args:
        filepath: Path to the Nanonis .dat file.

    Returns:
        Dictionary of parameters.
    """
    parameters = {}
    with open(filepath) as dat_file:
        for line in dat_file:
            if line == "\n":
                break
            if "User" in line or line.split("\t")[0] == "":
                continue
            label, value, _ = line.split("\t")
            try:
                value = float(value)
            except ValueError:
                pass
            parameters[label.replace("Oscillation Control>", "")] = value
    return parameters
def read_nanonis_spm_data(self, filepath: pathlib.Path) ‑> pySPM.SXM.SXM
-
Read a Nanonis .sxm data file.
Args
filepath
- Path to the .sxm file.
Returns
pySPM.SXM object containing the data.
Expand source code
@_add_base_read_path
@_check_extension(".sxm")
def read_nanonis_spm_data(self, filepath: Path) -> pySPM.SXM:
    """Open a Nanonis .sxm scan with pySPM.

    Args:
        filepath: Path to the .sxm file.

    Returns:
        pySPM.SXM object wrapping the scan.
    """
    scan = pySPM.SXM(filepath)
    return scan
def read_oceanoptics_data(self, filepath: str) ‑> pandas.core.frame.DataFrame
-
Read spectrometer data from an OceanOptics spectrometer.
Args
filepath
- Path to the data file.
Returns
DataFrame containing the wavelength and intensity data.
Expand source code
@_add_base_read_path
@_check_extension(".txt")
def read_oceanoptics_data(self, filepath: str) -> pd.DataFrame:
    """Load an OceanOptics spectrometer export as a two-column DataFrame.

    The first 14 lines are skipped (presumably the spectrometer's header;
    confirm against an export file). The remaining tab-separated rows are
    read as wavelength/intensity pairs.

    Args:
        filepath: Path to the data file.

    Returns:
        DataFrame with "wavelength" and "intensity" columns.
    """
    column_names = ["wavelength", "intensity"]
    spectrum = pd.read_csv(filepath, sep="\t", skiprows=14, names=column_names)
    return spectrum
def read_pfeiffer_data(self, filepath: pathlib.Path) ‑> pandas.core.frame.DataFrame
-
Read data stored by Pfeiffer vacuum monitoring software.
Args
filepath
- Path to the text file.
Returns
DataFrame containing the data.
Expand source code
@_add_base_read_path
@_check_extension(".txt")
def read_pfeiffer_data(self, filepath: Path) -> pd.DataFrame:
    """Read data stored by Pfeiffer vacuum monitoring software.

    Args:
        filepath: Path to the text file.

    Returns:
        DataFrame containing the data, indexed by the combined "Date" column.
    """
    # Skip rows 0, 2, 3 and 4 so that row 1 becomes the column header
    df = pd.read_csv(filepath, sep="\t", skiprows=[0, 2, 3, 4])
    # Combine date and time columns together
    df["Date"] = df["Date"] + " " + df["Time"]
    df = df.drop("Time", axis=1)
    # Format inference is pandas' default; the former `infer_datetime_format`
    # flag is deprecated since pandas 2.0 and only triggers a warning there
    df["Date"] = pd.to_datetime(df["Date"])
    # Set datetime as index
    df = df.set_index("Date", drop=True)
    return df
def read_pixelscanner_data(self, filepath: pathlib.Path) ‑> (pySPM.SPM_image, pySPM.SPM_image)
-
Read data from a PixelScanner measurement.
Args
filepath
- Path to the data file.
Returns
Forward and backward scan data.
Expand source code
def read_pixelscanner_data(self, filepath: Path) -> (
        pySPM.SPM_image, pySPM.SPM_image):
    """Read data from a PixelScanner measurement.

    Supported column layouts, checked in this order:

    * ``count_rates`` or ``Count Rates (cps)``: a single column that
      ``__get_forward_backward_counts`` splits into forward and backward
      traces, so the file must hold ``2 * num_pixels**2`` rows.
    * ``forward (cps)`` / ``backward (cps)`` (old format): one row per pixel,
      so the file must hold ``num_pixels**2`` rows.

    Args:
        filepath: Path to the data file.

    Returns:
        Forward and backward scan data.

    Raises:
        ValueError: If the number of rows does not describe a square scan.
        KeyError: If none of the supported count-rate columns is present.
    """
    df = self.read_into_dataframe(filepath)

    def _side_length(num_points: int) -> int:
        # A square scan of num_pixels x num_pixels has exactly num_pixels**2 points
        num_pixels = int(np.sqrt(num_points))
        if num_pixels ** 2 != num_points:
            raise ValueError("Number of pixels does not match data length.")
        return num_pixels

    for column in ("count_rates", "Count Rates (cps)"):
        if column in df.columns:
            # Forward and backward traces share one column: two rows per pixel
            if len(df) % 2 != 0:
                raise ValueError("Number of pixels does not match data length.")
            num_pixels = _side_length(len(df) // 2)
            fwd, bwd = self.__get_forward_backward_counts(df[column], num_pixels)
            break
    else:
        # Support old data format: separate forward/backward columns, one row per pixel
        fwd_counts = df["forward (cps)"].to_numpy()
        bwd_counts = df["backward (cps)"].to_numpy()
        num_pixels = _side_length(len(df))
        fwd = fwd_counts.reshape(num_pixels, num_pixels)
        bwd = bwd_counts.reshape(num_pixels, num_pixels)

    fwd = pySPM.SPM_image(fwd, channel="Forward", _type="NV-PL")
    bwd = pySPM.SPM_image(bwd, channel="Backward", _type="NV-PL")
    return fwd, bwd
def read_pkl(self, filepath: pathlib.Path) ‑> dict
-
Read pickle files into a dictionary.
Expand source code
@_add_base_read_path
@_check_extension(".pkl")
def read_pkl(self, filepath: Path) -> dict:
    """Unpickle a .pkl file and return the stored object.

    NOTE: unpickling can execute arbitrary code; only load trusted files.

    Args:
        filepath: Path to the pickle file.

    Returns:
        The unpickled object (expected to be a dictionary).
    """
    with open(filepath, 'rb') as pkl_file:
        return pickle.load(pkl_file)
def read_pys(self, filepath: pathlib.Path) ‑> dict
-
Read raw .pys data files into a dictionary.
Expand source code
@_add_base_read_path
@_check_extension(".pys")
def read_pys(self, filepath: Path) -> dict:
    """Read raw .pys data files into a dictionary.

    The file is unpickled with ``encoding="bytes"`` so that legacy Python 2
    pickles load. Their byte-string keys are decoded as UTF-8. Keys that are
    already ``str`` (e.g. dictionaries pickled under Python 3, as ``save_pys``
    does) are kept unchanged instead of raising ``AttributeError``.

    WARNING: ``allow_pickle=True`` means loading can execute arbitrary code;
    only read trusted files.

    Args:
        filepath: Path to the .pys file.

    Returns:
        Dictionary with ``str`` keys.
    """
    byte_dict = np.load(str(filepath), encoding="bytes", allow_pickle=True)
    return {
        (key.decode('utf8') if isinstance(key, bytes) else key): value
        for key, value in byte_dict.items()
    }
def read_qudi_parameters(self, filepath: pathlib.Path) ‑> dict
-
Extract parameters from a qudi dat file.
Args
filepath
- Path to the qudi .dat file
Returns
Dictionary of parameters
Expand source code
@_add_base_read_path
@_check_extension(".dat")
def read_qudi_parameters(self, filepath: Path) -> dict:
    """Parse the header of a qudi .dat file into a dictionary of parameters.

    Header lines are read until the '#=====' separator line. The first
    character (the '#') of each line is dropped, then:

    * lines with exactly one ':' are stored as ``label -> literal_eval(value)``,
      skipping entries with an empty value;
    * lines with exactly three ':' are treated as timestamps. The text after
      the first ':' (with the remaining colons removed) is parsed with the
      format "%d.%m.%Y %Hh%Mmin%Ss".

    Lines that fail to parse are skipped silently (best effort).
    NOTE(review): timestamps are labelled UTC without conversion — confirm
    that qudi writes UTC.

    Args:
        filepath: Path to the qudi .dat file

    Returns:
        Dictionary of parameters
    """
    params = {}
    with open(filepath) as file:
        for raw_line in file:
            if raw_line == '#=====\n':
                break
            # noinspection PyBroadException
            try:
                line = raw_line[1:]
                colon_count = line.count(":")
                if colon_count == 1:
                    label, value = line.split(":")
                    if value != "\n":
                        params[label] = ast.literal_eval(inspect.cleandoc(value))
                elif colon_count == 3:
                    label, *remainder = line.split(":")
                    timestamp_str = "".join(remainder).strip()
                    params[label] = datetime.datetime.strptime(
                        timestamp_str, "%d.%m.%Y %Hh%Mmin%Ss"
                    ).replace(tzinfo=datetime.timezone.utc)
            except Exception:
                pass
    return params
def save_df(self, filepath: pathlib.Path, df: pandas.core.frame.DataFrame)
-
Save a DataFrame as a tab-separated csv file.
Expand source code
@_add_base_write_path
@_check_extension(".csv")
def save_df(self, filepath: Path, df: pd.DataFrame):
    """Save a DataFrame as a tab-separated csv file (UTF-8).

    The extension check is ".csv": the previous ".pys" check mislabelled text
    output with the pickle-based .pys format that ``read_pys`` expects.

    Args:
        filepath: Path of the csv file to write.
        df: DataFrame to save; its index is written as the first column.
    """
    df.to_csv(filepath, sep='\t', encoding='utf-8')
def save_figures(self, filepath: pathlib.Path, fig: matplotlib.figure.Figure, **kwargs)
-
Saves figures from matplotlib plot data.
By default, saves as jpg, png, pdf and svg.
Args
fig
- Matplotlib figure to save.
filepath
- Path of the figure to save; its suffix is replaced by each output format's extension.
only_jpg
- If True, only save as jpg (default: False).
only_pdf
- If True, only save as pdf (default: False).
**kwargs
- Keyword arguments passed to matplotlib.figure.Figure.savefig().
Expand source code
@_add_base_write_path
def save_figures(self, filepath: Path, fig: plt.Figure, **kwargs):
    """Saves figures from matplotlib plot data.

    By default, saves as jpg, pdf, svg and png.

    Args:
        filepath: Path of the figure to save; its suffix is replaced by each
            output format's extension.
        fig: Matplotlib figure to save.
        only_jpg: If True, only save as jpg (default: False). Takes
            precedence over ``only_pdf``.
        only_pdf: If True, only save as pdf (default: False).
        **kwargs: Keyword arguments passed to ``Figure.savefig()``. ``dpi``
            defaults to 200 unless given here.
    """
    # Always remove the format flags so they never leak into savefig(). This
    # also covers only_jpg=False, which previously left no extension list.
    only_jpg = kwargs.pop("only_jpg", False)
    only_pdf = kwargs.pop("only_pdf", False)

    if only_jpg:
        extensions = [".jpg"]
    elif only_pdf:
        extensions = [".pdf"]
    else:
        extensions = [".jpg", ".pdf", ".svg", ".png"]

    # Keep the historical default of 200 dpi while letting callers override it
    kwargs.setdefault("dpi", 200)
    for ext in extensions:
        fig.savefig(filepath.with_suffix(ext), **kwargs)
def save_pkl(self, filepath: pathlib.Path, obj: object)
-
Saves pickle files.
Args
filepath
- Path to the data file.
obj
- Object to be saved.
Expand source code
@_add_base_write_path
@_check_extension(".pkl")
def save_pkl(self, filepath: Path, obj: object):
    """Pickle an object to a .pkl file using the default pickle protocol.

    Args:
        filepath: Path to the data file.
        obj: Object to be pickled.
    """
    with open(filepath, 'wb') as pkl_file:
        pickle.dump(obj, pkl_file)
def save_pys(self, filepath: pathlib.Path, dictionary: dict)
-
Saves .pys files.
Args
filepath
- Path to the data file.
dictionary
- Dictionary to be saved.
Expand source code
@_add_base_write_path
@_check_extension(".pys")
def save_pys(self, filepath: Path, dictionary: dict):
    """Pickle a dictionary to a .pys file with pickle protocol 1.

    Args:
        filepath: Path to the data file.
        dictionary: Dictionary to be saved.
    """
    with open(filepath, 'wb') as pys_file:
        pickle.dump(dictionary, pys_file, protocol=1)