Add --num-* arguments to calibcant-calibrate.py.
[calibcant.git] / calibcant / analyze.py
index f8d1bdd6b7620d657e2ecb0fab7e498997b79049..ec33d5687c8d5302303ed5a65b911f2cd0ef3d33 100644 (file)
@@ -224,13 +224,22 @@ def analyze_all(config, data, raw_data, maximum_relative_error=1e-5,
     if not data.get('vibrations', None):
         data['vibration'] = _numpy.zeros(
                 (config['num-vibrations'],), dtype=float)
+    if 'raw' not in data:
+        data['raw'] = {}
+    if 'bump' not in data['raw']:
+        data['raw']['bump'] = _numpy.zeros((config['num-bumps'],), dtype=float)
+    if 'temperature' not in data['raw']:
+        data['raw']['temperature'] = _numpy.zeros(
+        (config['num-temperatures'],), dtype=float)
+    if 'vibration' not in data['raw']:
+        data['raw']['vibration'] = _numpy.zeros(
+        (config['num-vibrations'],), dtype=float)
     axis_config = config['afm']['piezo'].select_config(
         setting_name='axes',
         attribute_value=config['afm']['main-axis'],
         get_attribute=_get_axis_name)
     input_config = config['afm']['piezo'].select_config(
         setting_name='inputs', attribute_value='deflection')
-    bumps_changed = temperatures_changed = vibrations_changed = False
     calibration_group = None
     if not isinstance(group, _h5py.Group) and not dry_run:
         f = _h5py.File(filename, mode='a')
@@ -238,9 +247,8 @@ def analyze_all(config, data, raw_data, maximum_relative_error=1e-5,
     else:
         f = None
     try:
-        if len(data.get('raw', {}).get('bump', [])) != len(data['bump']):
-            bumps_changed = True
-        for i,bump in enumerate(raw_data['bump']):
+        bumps_changed = len(data['raw']['bump']) != len(data['bump'])
+        for i,bump in enumerate(raw_data.get('bump', [])):  # compare values
             data['bump'][i],changed = check_bump(
                 index=i, bump=bump, config=config, z_axis_config=axis_config,
                 deflection_channel_config=input_config, plot=plot,
@@ -249,10 +257,9 @@ def analyze_all(config, data, raw_data, maximum_relative_error=1e-5,
                 bumps_changed = True
                 bump_group = _h5_create_group(group, 'bump/{}'.format(i))
                 _bump_save(group=bump_group, processed=data['bump'][i])
-        if len(data.get('raw', {}).get('temperature', [])
-               ) != len(data['temperature']):
-            temperatures_changed = True
-        for i,temperature in enumerate(raw_data['temperature']):
+        temperatures_changed = len(data['raw']['temperature']) != len(
+            data['temperature'])
+        for i,temperature in enumerate(raw_data.get('temperature', [])):
             data['temperature'][i],changed = check_temperature(
                 index=i, temperature=temperature, config=config,
                 maximum_relative_error=maximum_relative_error)
@@ -262,10 +269,9 @@ def analyze_all(config, data, raw_data, maximum_relative_error=1e-5,
                     group, 'temperature/{}'.format(i))
                 _temperature_save(
                     group=temperature_group, processed=data['temperature'][i])
-        if len(data.get('raw', {}).get('vibration', [])
-               ) != len(data['vibration']):
-            vibrations_changed = True
-        for i,vibration in enumerate(raw_data['vibration']):
+        vibrations_changed = len(data['raw']['vibration']) != len(
+            data['vibration'])
+        for i,vibration in enumerate(raw_data.get('vibration', [])):
             data['vibration'][i],changed = check_vibration(
                     index=i, vibration=vibration, config=config,
                     deflection_channel_config=input_config, plot=plot,
@@ -288,18 +294,20 @@ def analyze_all(config, data, raw_data, maximum_relative_error=1e-5,
             if vibrations_changed:
                 save_results(
                     group=calibration_group, vibration=data['vibration'])
-        if len(raw_data['bump']) != len(data['bump']):
+        if len(raw_data.get('bump', [])) != len(data['bump']):
             raise ValueError(
                 'not enough raw bump data: {} of {}'.format(
-                    len(raw_data['bump']), len(data['bump'])))
-        if len(raw_data['temperature']) != len(data['temperature']):
+                    len(raw_data.get('bump', [])), len(data['bump'])))
+        if len(raw_data.get('temperature', [])) != len(data['temperature']):
             raise ValueError(
                 'not enough raw temperature data: {} of {}'.format(
-                    len(raw_data['temperature']), len(data['temperature'])))
+                    len(raw_data.get('temperature', [])),
+                    len(data['temperature'])))
         if len(raw_data['vibration']) != len(data['vibration']):
             raise ValueError(
                 'not enough raw vibration data: {} of {}'.format(
-                    len(raw_data['vibration']), len(data['vibration'])))
+                    len(raw_data.get('vibration', [])),
+                    len(data['vibration'])))
         k,k_s,changed = check_calibration(
             k=data.get('processed', {}).get('spring_constant', None),
             k_s=data.get('processed', {}).get(
@@ -317,9 +325,9 @@ def analyze_all(config, data, raw_data, maximum_relative_error=1e-5,
         if f:
             f.close()
     if plot:
-        _plot(bumps=data['raw']['bump'],
-             temperatures=data['raw']['temperature'],
-             vibrations=data['raw']['vibration'])
+        _plot(bumps=data['bump'],
+              temperatures=data['temperature'],
+              vibrations=data['vibration'])
     return (k, k_s)
 
 def check_bump(index, bump, config=None, maximum_relative_error=0, **kwargs):