Skip to content

Commit a11021c

Browse files
GuyStenpaulromano
andauthored
Consistent XML parsing using functions from _xml module (#3517)
Co-authored-by: Paul Romano <paul.k.romano@gmail.com>
1 parent e36c0ae commit a11021c

22 files changed

+313
-321
lines changed

openmc/_xml.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def get_text(elem, name, default=None):
6464

6565

6666

67-
def get_elem_tuple(elem, name, dtype=int):
68-
"""Helper function to get a tuple of values from an elem
67+
def get_elem_list(elem, name, dtype=int):
68+
"""Helper function to get a list of values from an elem
6969
7070
Parameters
7171
----------
@@ -78,9 +78,9 @@ def get_elem_tuple(elem, name, dtype=int):
7878
7979
Returns
8080
-------
81-
tuple of dtype
82-
Data read from the tuple
81+
list of dtype
82+
Data read from the list
8383
"""
84-
subelem = elem.find(name)
85-
if subelem is not None:
86-
return tuple([dtype(x) for x in subelem.text.split()])
84+
text = get_text(elem, name)
85+
if text is not None:
86+
return [dtype(x) for x in text.split()]

openmc/cell.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import openmc
1010
import openmc.checkvalue as cv
11-
from ._xml import get_text
11+
from ._xml import get_elem_list, get_text
1212
from .mixin import IDManagerMixin
1313
from .plots import add_plot_params
1414
from .region import Region, Complement
@@ -689,9 +689,8 @@ def from_xml_element(cls, elem, surfaces, materials, get_universe):
689689
c = cls(cell_id, name)
690690

691691
# Assign material/distributed materials or fill
692-
mat_text = get_text(elem, 'material')
693-
if mat_text is not None:
694-
mat_ids = mat_text.split()
692+
mat_ids = get_elem_list(elem, 'material', str)
693+
if mat_ids is not None:
695694
if len(mat_ids) > 1:
696695
c.fill = [materials[i] for i in mat_ids]
697696
else:
@@ -706,19 +705,18 @@ def from_xml_element(cls, elem, surfaces, materials, get_universe):
706705
c.region = Region.from_expression(region, surfaces)
707706

708707
# Check for other attributes
709-
t = get_text(elem, 'temperature')
710-
if t is not None:
711-
if ' ' in t:
712-
c.temperature = [float(t_i) for t_i in t.split()]
708+
temperature = get_elem_list(elem, 'temperature', float)
709+
if temperature is not None:
710+
if len(temperature) > 1:
711+
c.temperature = temperature
713712
else:
714-
c.temperature = float(t)
713+
c.temperature = temperature[0]
715714
v = get_text(elem, 'volume')
716715
if v is not None:
717716
c.volume = float(v)
718717
for key in ('temperature', 'rotation', 'translation'):
719-
value = get_text(elem, key)
720-
if value is not None:
721-
values = [float(x) for x in value.split()]
718+
values = get_elem_list(elem, key, float)
719+
if values is not None:
722720
if key == 'rotation' and len(values) == 9:
723721
values = np.array(values).reshape(3, 3)
724722
setattr(c, key, values)

openmc/dagmc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import openmc
1010
import openmc.checkvalue as cv
11-
from ._xml import get_text
11+
from ._xml import get_elem_list, get_text
1212
from .checkvalue import check_type, check_value
1313
from .surface import _BOUNDARY_TYPES
1414
from .bounding_box import BoundingBox
@@ -468,8 +468,8 @@ def from_xml_element(cls, elem, mats = None):
468468
if name is not None:
469469
out.name = name
470470

471-
out.auto_geom_ids = bool(elem.get('auto_geom_ids'))
472-
out.auto_mat_ids = bool(elem.get('auto_mat_ids'))
471+
out.auto_geom_ids = bool(get_text(elem, "auto_geom_ids"))
472+
out.auto_mat_ids = bool(get_text(elem, "auto_mat_ids"))
473473

474474
el_mat_override = elem.find('material_overrides')
475475
if el_mat_override is not None:
@@ -480,7 +480,7 @@ def from_xml_element(cls, elem, mats = None):
480480
out._material_overrides = {}
481481
for elem in el_mat_override.findall('cell_override'):
482482
cell_id = int(get_text(elem, 'id'))
483-
mat_ids = get_text(elem, 'material_ids').split(' ')
483+
mat_ids = get_elem_list(elem, "material_ids", str) or []
484484
mat_objs = [mats[mat_id] for mat_id in mat_ids]
485485
out._material_overrides[cell_id] = mat_objs
486486

openmc/data/library.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import lxml.etree as ET
66

77
import openmc
8-
from openmc._xml import clean_indentation
8+
from openmc._xml import get_elem_list, get_text, clean_indentation
99

1010

1111
class DataLibrary(list):
@@ -172,17 +172,17 @@ def from_xml(cls, path=None):
172172
directory = os.path.dirname(path)
173173

174174
for lib_element in root.findall('library'):
175-
filename = os.path.join(directory, lib_element.attrib['path'])
176-
filetype = lib_element.attrib['type']
177-
materials = lib_element.attrib['materials'].split()
175+
filename = os.path.join(directory, get_text(lib_element, "path"))
176+
filetype = get_text(lib_element, "type")
177+
materials = get_elem_list(lib_element, "materials", str) or []
178178
library = {'path': filename, 'type': filetype,
179179
'materials': materials}
180180
data.libraries.append(library)
181181

182182
# get depletion chain data
183183
dep_node = root.find("depletion_chain")
184184
if dep_node is not None:
185-
filename = os.path.join(directory, dep_node.attrib['path'])
185+
filename = os.path.join(directory, get_text(dep_node, "path"))
186186
library = {'path': filename, 'type': 'depletion_chain',
187187
'materials': []}
188188
data.libraries.append(library)

openmc/deplete/chain.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from openmc.data import gnds_name, zam
2323
from openmc.exceptions import DataError
2424
from .nuclide import FissionYieldDistribution, Nuclide
25+
from .._xml import get_text
2526
import openmc.data
2627

2728

@@ -553,7 +554,7 @@ def from_xml(cls, filename, fission_q=None):
553554
root = ET.parse(str(filename))
554555

555556
for i, nuclide_elem in enumerate(root.findall('nuclide')):
556-
this_q = fission_q.get(nuclide_elem.get("name"))
557+
this_q = fission_q.get(get_text(nuclide_elem, "name"))
557558

558559
nuc = Nuclide.from_xml(nuclide_elem, root, this_q)
559560
chain.add_nuclide(nuc)

openmc/deplete/nuclide.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from openmc.checkvalue import check_type
1616
from openmc.stats import Univariate
17+
from .._xml import get_elem_list, get_text
1718

1819
__all__ = [
1920
"DecayTuple", "ReactionTuple", "Nuclide", "FissionYield",
@@ -225,38 +226,39 @@ def from_xml(cls, element, root=None, fission_q=None):
225226
226227
"""
227228
nuc = cls()
228-
nuc.name = element.get('name')
229+
nuc.name = get_text(element, "name")
229230

230231
# Check for half-life
231-
if 'half_life' in element.attrib:
232-
nuc.half_life = float(element.get('half_life'))
233-
nuc.decay_energy = float(element.get('decay_energy', '0'))
232+
half_life = get_text(element, "half_life")
233+
if half_life is not None:
234+
nuc.half_life = float(half_life)
235+
nuc.decay_energy = float(get_text(element, "decay_energy", 0.0))
234236

235237
# Check for decay paths
236238
for decay_elem in element.iter('decay'):
237-
d_type = decay_elem.get('type')
238-
target = decay_elem.get('target')
239+
d_type = get_text(decay_elem, "type")
240+
target = get_text(decay_elem, "target")
239241
if target is not None and target.lower() == "nothing":
240242
target = None
241-
branching_ratio = float(decay_elem.get('branching_ratio'))
243+
branching_ratio = float(get_text(decay_elem, "branching_ratio"))
242244
nuc.decay_modes.append(DecayTuple(d_type, target, branching_ratio))
243245

244246
# Check for sources
245247
for src_elem in element.iter('source'):
246-
particle = src_elem.get('particle')
248+
particle = get_text(src_elem, "particle")
247249
distribution = Univariate.from_xml_element(src_elem)
248250
nuc.sources[particle] = distribution
249251

250252
# Check for reaction paths
251253
for reaction_elem in element.iter('reaction'):
252-
r_type = reaction_elem.get('type')
253-
Q = float(reaction_elem.get('Q', '0'))
254-
branching_ratio = float(reaction_elem.get('branching_ratio', '1'))
254+
r_type = get_text(reaction_elem, "type")
255+
Q = float(get_text(reaction_elem, "Q", 0.0))
256+
branching_ratio = float(get_text(reaction_elem, "branching_ratio", 1.0))
255257

256258
# If the type is not fission, get target and Q value, otherwise
257259
# just set null values
258260
if r_type != 'fission':
259-
target = reaction_elem.get('target')
261+
target = get_text(reaction_elem, "target")
260262
if target is not None and target.lower() == "nothing":
261263
target = None
262264
else:
@@ -271,7 +273,7 @@ def from_xml(cls, element, root=None, fission_q=None):
271273
fpy_elem = element.find('neutron_fission_yields')
272274
if fpy_elem is not None:
273275
# Check for use of FPY from other nuclide
274-
parent = fpy_elem.get('parent')
276+
parent = get_text(fpy_elem, "parent")
275277
if parent is not None:
276278
assert root is not None
277279
fpy_elem = root.find(
@@ -529,9 +531,9 @@ def from_xml_element(cls, element):
529531
"""
530532
all_yields = {}
531533
for yield_elem in element.iter("fission_yields"):
532-
energy = float(yield_elem.get("energy"))
533-
products = yield_elem.find("products").text.split()
534-
yields = map(float, yield_elem.find("data").text.split())
534+
energy = float(get_text(yield_elem, "energy"))
535+
products = get_elem_list(yield_elem, "products", str) or []
536+
yields = get_elem_list(yield_elem, "data", float) or []
535537
# Get a map of products to their corresponding yield
536538
all_yields[energy] = dict(zip(products, yields))
537539

openmc/filter.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .mixin import IDManagerMixin
1818
from .surface import Surface
1919
from .universe import UniverseBase
20-
from ._xml import get_text
20+
from ._xml import get_elem_list, get_text
2121

2222

2323
_FILTER_TYPES = (
@@ -259,17 +259,15 @@ def from_xml_element(cls, elem, **kwargs):
259259
Filter object
260260
261261
"""
262-
filter_type = elem.get('type')
263-
if filter_type is None:
264-
filter_type = elem.find('type').text
262+
filter_type = get_text(elem, "type")
265263

266264
# If the filter type matches this class's short_name, then
267265
# there is no overridden from_xml_element method
268266
if filter_type == cls.short_name.lower():
269267
# Get bins from element -- the default here works for any filters
270268
# that just store a list of bins that can be represented as integers
271-
filter_id = int(elem.get('id'))
272-
bins = [int(x) for x in get_text(elem, 'bins').split()]
269+
filter_id = int(get_text(elem, "id"))
270+
bins = get_elem_list(elem, "bins", int) or []
273271
return cls(bins, filter_id=filter_id)
274272

275273
# Search through all subclasses and find the one matching the HDF5
@@ -701,8 +699,8 @@ def to_xml_element(self):
701699

702700
@classmethod
703701
def from_xml_element(cls, elem, **kwargs):
704-
filter_id = int(elem.get('id'))
705-
bins = [int(x) for x in get_text(elem, 'bins').split()]
702+
filter_id = int(get_text(elem, "id"))
703+
bins = get_elem_list(elem, "bins", int) or []
706704
cell_instances = list(zip(bins[::2], bins[1::2]))
707705
return cls(cell_instances, filter_id=filter_id)
708706

@@ -784,8 +782,8 @@ def from_hdf5(cls, group, **kwargs):
784782

785783
@classmethod
786784
def from_xml_element(cls, elem, **kwargs):
787-
filter_id = int(elem.get('id'))
788-
bins = get_text(elem, 'bins').split()
785+
filter_id = int(get_text(elem, "id"))
786+
bins = get_elem_list(elem, "bins", str) or []
789787
return cls(bins, filter_id=filter_id)
790788

791789

@@ -1004,12 +1002,12 @@ def to_xml_element(self):
10041002
def from_xml_element(cls, elem: ET.Element, **kwargs) -> MeshFilter:
10051003
mesh_id = int(get_text(elem, 'bins'))
10061004
mesh_obj = kwargs['meshes'][mesh_id]
1007-
filter_id = int(elem.get('id'))
1005+
filter_id = int(get_text(elem, "id"))
10081006
out = cls(mesh_obj, filter_id=filter_id)
10091007

1010-
translation = elem.get('translation')
1008+
translation = get_elem_list(elem, "translation", float) or []
10111009
if translation:
1012-
out.translation = [float(x) for x in translation.split()]
1010+
out.translation = translation
10131011
return out
10141012

10151013

@@ -1149,16 +1147,16 @@ def to_xml_element(self):
11491147

11501148
@classmethod
11511149
def from_xml_element(cls, elem: ET.Element, **kwargs) -> MeshMaterialFilter:
1152-
filter_id = int(elem.get('id'))
1153-
mesh_id = int(elem.get('mesh'))
1150+
filter_id = int(get_text(elem, "id"))
1151+
mesh_id = int(get_text(elem, "mesh"))
11541152
mesh_obj = kwargs['meshes'][mesh_id]
1155-
bins = [int(x) for x in get_text(elem, 'bins').split()]
1153+
bins = get_elem_list(elem, "bins", int) or []
11561154
bins = list(zip(bins[::2], bins[1::2]))
11571155
out = cls(mesh_obj, bins, filter_id=filter_id)
11581156

1159-
translation = elem.get('translation')
1157+
translation = get_elem_list(elem, "translation", float) or []
11601158
if translation:
1161-
out.translation = [float(x) for x in translation.split()]
1159+
out.translation = translation
11621160
return out
11631161

11641162
@classmethod
@@ -1557,8 +1555,8 @@ def to_xml_element(self):
15571555

15581556
@classmethod
15591557
def from_xml_element(cls, elem, **kwargs):
1560-
filter_id = int(elem.get('id'))
1561-
bins = [float(x) for x in get_text(elem, 'bins').split()]
1558+
filter_id = int(get_text(elem, "id"))
1559+
bins = get_elem_list(elem, "bins", float) or []
15621560
return cls(bins, filter_id=filter_id)
15631561

15641562

@@ -2447,12 +2445,13 @@ def to_xml_element(self):
24472445

24482446
@classmethod
24492447
def from_xml_element(cls, elem, **kwargs):
2450-
filter_id = int(elem.get('id'))
2451-
energy = [float(x) for x in get_text(elem, 'energy').split()]
2452-
y = [float(x) for x in get_text(elem, 'y').split()]
2448+
filter_id = int(get_text(elem, "id"))
2449+
energy = get_elem_list(elem, "energy", float) or []
2450+
y = get_elem_list(elem, "y", float) or []
24532451
out = cls(energy, y, filter_id=filter_id)
2454-
if elem.find('interpolation') is not None:
2455-
out.interpolation = elem.find('interpolation').text
2452+
interpolation = get_text(elem, "interpolation")
2453+
if interpolation is not None:
2454+
out.interpolation = interpolation
24562455
return out
24572456

24582457
def can_merge(self, other):

0 commit comments

Comments
 (0)