# -*- coding: utf-8 -*-
"""
Testing utilities for ObsPy.
:copyright:
The ObsPy Development Team (devs@obspy.org)
:license:
GNU Lesser General Public License, Version 3
(https://www.gnu.org/copyleft/lesser.html)
"""
import difflib
import doctest
import inspect
import io
import os
import re
import warnings
import numpy as np
from lxml import etree
MODULE_TEST_SKIP_CHECKS = {}
[docs]def compare_xml_strings(doc1, doc2):
"""
Simple helper function to compare two XML strings.
:type doc1: str
:type doc2: str
"""
# Compat py2k and py3k
try:
doc1 = doc1.encode()
doc2 = doc2.encode()
except Exception:
pass
obj1 = etree.fromstring(doc1).getroottree()
obj2 = etree.fromstring(doc2).getroottree()
buf = io.BytesIO()
obj1.write_c14n(buf)
buf.seek(0, 0)
str1 = buf.read().decode()
str1 = [_i.strip() for _i in str1.splitlines()]
buf = io.BytesIO()
obj2.write_c14n(buf)
buf.seek(0, 0)
str2 = buf.read().decode()
str2 = [_i.strip() for _i in str2.splitlines()]
unified_diff = difflib.unified_diff(str1, str2)
err_msg = "\n".join(unified_diff)
if err_msg: # pragma: no cover
raise AssertionError("Strings are not equal.\n" + err_msg)
[docs]def remove_unique_ids(xml_string, remove_creation_time=False):
"""
Removes unique ID parts of e.g. 'publicID="..."' attributes from xml
strings.
:type xml_string: str
:param xml_string: xml string to process
:type remove_creation_time: bool
:param xml_string: controls whether to remove 'creationTime' tags or not.
:rtype: str
"""
prefixes = ["id", "publicID", "pickID", "originID", "preferredOriginID",
"preferredMagnitudeID", "preferredFocalMechanismID",
"referenceSystemID", "methodID", "earthModelID",
"triggeringOriginID", "derivedOriginID", "momentMagnitudeID",
"greensFunctionID", "filterID", "amplitudeID",
"stationMagnitudeID", "earthModelID", "slownessMethodID",
"pickReference", "amplitudeReference"]
if remove_creation_time:
prefixes.append("creationTime")
for prefix in prefixes:
xml_string = re.sub("%s='.*?'" % prefix, '%s=""' % prefix, xml_string)
xml_string = re.sub('%s=".*?"' % prefix, '%s=""' % prefix, xml_string)
xml_string = re.sub("<%s>.*?</%s>" % (prefix, prefix),
'<%s/>' % prefix, xml_string)
return xml_string
[docs]def get_all_py_files():
"""
Return a list with full absolute paths to all .py files in ObsPy file tree.
:rtype: list[str]
"""
util_dir = os.path.abspath(inspect.getfile(inspect.currentframe()))
obspy_dir = os.path.dirname(os.path.dirname(os.path.dirname(util_dir)))
py_files = set()
# Walk the obspy directory structure
for dirpath, _, filenames in os.walk(obspy_dir):
py_files.update([os.path.abspath(os.path.join(dirpath, i)) for i in
filenames if i.endswith(".py")])
return sorted(py_files)
[docs]class WarningsCapture(object):
"""
Try hard to capture all warnings.
Aims to be a reliable drop-in replacement for built-in
warnings.catch_warnings() context manager.
Based on pytest's _DeprecatedCallContext context manager.
"""
[docs] def __enter__(self):
self.captured_warnings = []
self._old_warn = warnings.warn
self._old_warn_explicit = warnings.warn_explicit
warnings.warn_explicit = self._warn_explicit
warnings.warn = self._warn
return self
[docs] def _warn_explicit(self, message, category, *args, **kwargs):
self.captured_warnings.append(
warnings.WarningMessage(message=category(message),
category=category,
filename="", lineno=0))
[docs] def _warn(self, message, category=Warning, *args, **kwargs):
if isinstance(message, Warning):
self.captured_warnings.append(
warnings.WarningMessage(
message=category(message), category=category or Warning,
filename="", lineno=0))
else:
self.captured_warnings.append(
warnings.WarningMessage(
message=category(message), category=category,
filename="", lineno=0))
[docs] def __exit__(self, exc_type, exc_val, exc_tb):
warnings.warn_explicit = self._old_warn_explicit
warnings.warn = self._old_warn
[docs] def __len__(self):
return len(self.captured_warnings)
[docs] def __getitem__(self, key):
return self.captured_warnings[key]
[docs]def create_diverse_catalog():
"""
Create a catalog with a single event that has many features.
Uses most the event related classes.
"""
# imports are here in order to avoid circular import issues
import obspy.core.event as ev
from obspy import UTCDateTime, Catalog
# local dict for storing state
state = dict(time=UTCDateTime('2016-05-04T12:00:01'))
def _create_event():
event = ev.Event(
event_type='mining explosion',
event_descriptions=[_get_event_description()],
picks=[_create_pick()],
origins=[_create_origins()],
station_magnitudes=[_get_station_mag()],
magnitudes=[_create_magnitudes()],
amplitudes=[_get_amplitudes()],
focal_mechanisms=[_get_focal_mechanisms()],
)
# set preferred origin, focal mechanism, magnitude
preferred_objects = dict(
origin=event.origins[-1].resource_id,
focal_mechanism=event.focal_mechanisms[-1].resource_id,
magnitude=event.magnitudes[-1].resource_id,
)
for item, value in preferred_objects.items():
setattr(event, 'preferred_' + item + '_id', value)
event.scope_resource_ids()
return event
def _create_pick():
# setup some of the classes
creation = ev.CreationInfo(
agency='SwanCo',
author='Indago',
creation_time=UTCDateTime(),
version='10.10',
author_url=ev.ResourceIdentifier('smi:local/me.com'),
)
pick = ev.Pick(
time=state['time'],
comments=[ev.Comment(x) for x in 'BOB'],
evaluation_mode='manual',
evaluation_status='final',
creation_info=creation,
phase_hint='P',
polarity='positive',
onset='emergent',
back_azimith_errors={"uncertainty": 10},
slowness_method_id=ev.ResourceIdentifier('smi:local/slow'),
backazimuth=122.1,
horizontal_slowness=12,
method_id=ev.ResourceIdentifier(),
horizontal_slowness_errors={'uncertainty': 12},
filter_id=ev.ResourceIdentifier(),
waveform_id=ev.WaveformStreamID('UU', 'FOO', '--', 'HHZ'),
)
state['pick_id'] = pick.resource_id
return pick
def _create_origins():
ori = ev.Origin(
resource_id=ev.ResourceIdentifier('smi:local/First'),
time=UTCDateTime('2016-05-04T12:00:00'),
time_errors={'uncertainty': .01},
longitude=-111.12525,
longitude_errors={'uncertainty': .020},
latitude=47.48589325,
latitude_errors={'uncertainty': .021},
depth=2.123,
depth_errors={'uncertainty': 1.22},
depth_type='from location',
time_fixed=False,
epicenter_fixed=False,
reference_system_id=ev.ResourceIdentifier(),
method_id=ev.ResourceIdentifier(),
earth_model_id=ev.ResourceIdentifier(),
arrivals=[_get_arrival()],
composite_times=[_get_composite_times()],
quality=_get_origin_quality(),
origin_type='hypocenter',
origin_uncertainty=_get_origin_uncertainty(),
region='US',
evaluation_mode='manual',
evaluation_status='final',
)
state['origin_id'] = ori.resource_id
return ori
def _get_arrival():
return ev.Arrival(
resource_id=ev.ResourceIdentifier('smi:local/Ar1'),
pick_id=state['pick_id'],
phase='P',
time_correction=.2,
azimuth=12,
distance=10,
takeoff_angle=15,
takeoff_angle_errors={'uncertainty': 10.2},
time_residual=.02,
horizontal_slowness_residual=12.2,
backazimuth_residual=12.2,
time_weight=.23,
horizontal_slowness_weight=12,
backazimuth_weight=12,
earth_model_id=ev.ResourceIdentifier(),
commens=[ev.Comment(x) for x in 'Nothing'],
)
def _get_composite_times():
return ev.CompositeTime(
year=2016,
year_errors={'uncertainty': 0},
month=5,
month_errors={'uncertainty': 0},
day=4,
day_errors={'uncertainty': 0},
hour=0,
hour_errors={'uncertainty': 0},
minute=0,
minute_errors={'uncertainty': 0},
second=0,
second_errors={'uncertainty': .01}
)
def _get_origin_quality():
return ev.OriginQuality(
associate_phase_count=1,
used_phase_count=1,
associated_station_count=1,
used_station_count=1,
depth_phase_count=1,
standard_error=.02,
azimuthal_gap=.12,
ground_truth_level='GT0',
)
def _get_origin_uncertainty():
return ev.OriginUncertainty(
horizontal_uncertainty=1.2,
min_horizontal_uncertainty=.12,
max_horizontal_uncertainty=2.2,
confidence_ellipsoid=_get_confidence_ellipsoid(),
preferred_description="uncertainty ellipse",
)
def _get_confidence_ellipsoid():
return ev.ConfidenceEllipsoid(
semi_major_axis_length=12,
semi_minor_axis_length=12,
major_axis_plunge=12,
major_axis_rotation=12,
)
def _create_magnitudes():
return ev.Magnitude(
resource_id=ev.ResourceIdentifier(),
mag=5.5,
mag_errors={'uncertainty': .01},
magnitude_type='Mw',
origin_id=state['origin_id'],
station_count=1,
station_magnitude_contributions=[_get_station_mag_contrib()],
)
def _get_station_mag():
station_mag = ev.StationMagnitude(
mag=2.24,
)
state['station_mag_id'] = station_mag.resource_id
return station_mag
def _get_station_mag_contrib():
return ev.StationMagnitudeContribution(
station_magnitude_id=state['station_mag_id'],
)
def _get_event_description():
return ev.EventDescription(
text='some text about the EQ',
type='earthquake name',
)
def _get_amplitudes():
return ev.Amplitude(
generic_amplitude=.0012,
type='A',
unit='m',
period=1,
time_window=_get_timewindow(),
pick_id=state['pick_id'],
scalling_time=state['time'],
mangitude_hint='ML',
scaling_time_errors=ev.QuantityError(uncertainty=42.0),
)
def _get_timewindow():
return ev.TimeWindow(
begin=1.2,
end=2.2,
reference=UTCDateTime('2016-05-04T12:00:00'),
)
def _get_focal_mechanisms():
return ev.FocalMechanism(
nodal_planes=_get_nodal_planes(),
principal_axis=_get_principal_axis(),
azimuthal_gap=12,
station_polarity_count=12,
misfit=.12,
station_distribution_ratio=.12,
moment_tensor=_get_moment_tensor(),
)
def _get_nodal_planes():
return ev.NodalPlanes(
nodal_plane_1=ev.NodalPlane(strike=12, dip=2, rake=12),
nodal_plane_2=ev.NodalPlane(strike=12, dip=2, rake=12),
preferred_plane=2,
)
def _get_principal_axis():
return ev.PrincipalAxes(
t_axis=15,
p_axis=15,
n_axis=15,
)
def _get_moment_tensor():
return ev.MomentTensor(
scalar_moment=12213,
tensor=_get_tensor(),
variance=12.23,
variance_reduction=98,
double_couple=.22,
clvd=.55,
iso=.33,
source_time_function=_get_source_time_function(),
data_used=[_get_data_used()],
method_id=ev.ResourceIdentifier(),
inversion_type='general',
)
def _get_tensor():
return ev.Tensor(
m_rr=12,
m_rr_errors={'uncertainty': .01},
m_tt=12,
m_pp=12,
m_rt=12,
m_rp=12,
m_tp=12,
)
def _get_source_time_function():
return ev.SourceTimeFunction(
type='triangle',
duration=.12,
rise_time=.33,
decay_time=.23,
)
def _get_data_used():
return ev.DataUsed(
wave_type='body waves',
station_count=12,
component_count=12,
shortest_period=1,
longest_period=20,
)
events = [_create_event()]
return Catalog(events=events)
[docs]def setup_context_testcase(test_case, cm):
"""
Use a contextmanager to set up a unittest test case.
Inspired by Ned Batchelder's recipe found here: goo.gl/8TBJ7s.
:param test_case:
An instance of unittest.TestCase
:param cm:
Any instances which implements the context manager protocol,
ie its class definition implements __enter__ and __exit__ methods.
"""
val = cm.__enter__()
test_case.addCleanup(cm.__exit__, None, None, None)
return val
[docs]def streams_almost_equal(st1, st2, default_stats=True, rtol=1e-05, atol=1e-08,
equal_nan=True):
"""
Return True if two streams are almost equal.
:param st1: The first :class:`~obspy.core.stream.Stream` object.
:param st2: The second :class:`~obspy.core.stream.Stream` object.
:param default_stats:
If True only compare the default stats on the traces, such as seed
identification codes, start/end times, sampling_rates, etc. If
False also compare extra stats attributes such as processing and
format specific information.
:param rtol: The relative tolerance parameter passed to
:func:`~numpy.allclose` for comparing time series.
:param atol: The absolute tolerance parameter passed to
:func:`~numpy.allclose` for comparing time series.
:param equal_nan:
If ``True`` NaNs are evaluated equal when comparing the time
series.
:return: bool
.. rubric:: Example
1) Changes to the non-default parameters of the
:class:`~obspy.core.trace.Stats` objects of the stream's contained
:class:`~obspy.core.trace.Trace` objects will cause the streams to
be considered unequal, but they will be considered almost equal.
>>> from obspy import read
>>> st1 = read()
>>> st2 = read()
>>> # The traces should, of course, be equal.
>>> assert st1 == st2
>>> # Perform detrending on st1 twice so processing stats differ.
>>> st1 = st1.detrend('linear')
>>> st1 = st1.detrend('linear')
>>> st2 = st2.detrend('linear')
>>> # The traces are no longer equal, but are almost equal.
>>> assert st1 != st2
>>> assert streams_almost_equal(st1, st2)
2) Slight differences in each trace's data will cause the streams
to be considered unequal, but they will be almost equal if the
differences don't exceed the limits set by the ``rtol`` and
``atol`` parameters.
>>> from obspy import read
>>> st1 = read()
>>> st2 = read()
>>> # Perturb the trace data in st2 slightly.
>>> for tr in st2:
... tr.data *= (1 + 1e-6)
>>> # The streams are no longer equal.
>>> assert st1 != st2
>>> # But they are almost equal.
>>> assert streams_almost_equal(st1, st2)
>>> # Unless, of course, there is a large change.
>>> st1[0].data *= 10
>>> assert not streams_almost_equal(st1, st2)
"""
from obspy.core.stream import Stream
# Return False if both objects are not streams or not the same length.
are_streams = isinstance(st1, Stream) and isinstance(st2, Stream)
if not are_streams or not len(st1) == len(st2):
return False
# Kwargs to pass trace_almost_equal.
tr_kwargs = dict(default_stats=default_stats, rtol=rtol, atol=atol,
equal_nan=equal_nan)
# Ensure the streams are sorted (as done with the __equal__ method)
st1_sorted = st1.select()
st1_sorted.sort()
st2_sorted = st2.select()
st2_sorted.sort()
# Iterate over sorted trace pairs and determine if they are almost equal.
for tr1, tr2 in zip(st1_sorted, st2_sorted):
if not traces_almost_equal(tr1, tr2, **tr_kwargs):
return False # If any are not almost equal return None.
return True
[docs]def traces_almost_equal(tr1, tr2, default_stats=True, rtol=1e-05, atol=1e-08,
equal_nan=True):
"""
Return True if the two traces are almost equal.
:param tr1: The first :class:`~obspy.core.trace.Trace` object.
:param tr2: The second :class:`~obspy.core.trace.Trace` object.
:param default_stats:
If True only compare the default stats on the traces, such as seed
identification codes, start/end times, sampling_rates, etc. If
False also compare extra stats attributes such as processing and
format specific information.
:param rtol: The relative tolerance parameter passed to
:func:`~numpy.allclose` for comparing time series.
:param atol: The absolute tolerance parameter passed to
:func:`~numpy.allclose` for comparing time series.
:param equal_nan:
If ``True`` NaNs are evaluated equal when comparing the time
series.
:return: bool
"""
from obspy.core.trace import Trace
# If other isnt a trace, or data is not the same len return False.
if not isinstance(tr2, Trace) or len(tr1.data) != len(tr2.data):
return False
# First compare the array values
try: # Use equal_nan if available
all_close = np.allclose(tr1.data, tr2.data, rtol=rtol,
atol=atol, equal_nan=equal_nan)
except TypeError:
# This happens on very old versions of numpy. Essentially
# we just need to handle NaN detection on our own, if equal_nan.
is_close = np.isclose(tr1.data, tr2.data, rtol=rtol, atol=atol)
if equal_nan:
isnan = np.isnan(tr1.data) & np.isnan(tr2.data)
else:
isnan = np.zeros(tr1.data.shape).astype(bool)
all_close = np.all(isnan | is_close)
# Then compare the stats objects
stats1 = _make_stats_dict(tr1, default_stats)
stats2 = _make_stats_dict(tr2, default_stats)
return all_close and stats1 == stats2
[docs]def _make_stats_dict(tr, default_stats):
"""
Return a dict of stats from trace optionally including processing.
"""
from obspy.core.trace import Stats
if not default_stats:
return dict(tr.stats)
return {i: tr.stats[i] for i in Stats.defaults}
if __name__ == '__main__':
doctest.testmod(exclude_empty=True)