"""
.. module:: genericSMS
   :synopsis: This is a base class for describing Simplified Model Topologies using a rooted tree syntax.

.. moduleauthor:: Andre Lessa <lessa.a.p@gmail.com>
"""
from smodels.base.exceptions import SModelSBaseError as SModelSError
from smodels.base.inclusiveObjects import InclusiveValue
from collections import OrderedDict
from itertools import product
[docs]class GenericSMS(object):
A generic class for describing and manipulating
Simplified Model Topologies based on graphs concepts.
def __init__(self):
    """
    Initialize the basic attributes of an empty tree.
    """

    self._canonName = None
    self._rootIndex = None
    self._successors = OrderedDict()  # Stores the nodes and their successors (daughters)
    self._predecessors = {}  # Stores the nodes and their predecessors (parents)
    self._nodesMapping = {}  # Stores the nodeIndex->node object mapping
    self._nodeCanonNames = {}  # Stores the canonical names for the nodes
    self._finalStates = {}  # Stores the final states of each node
    self._sorted = False  # Tag SMS as sorted or not
def __hash__(self):
    """
    Return the identity-based hash inherited from object.
    """
    return object.__hash__(self)
def __repr__(self):
    """
    Returns the string representation of the tree (same as str(self)).
    """
    return str(self)
def __str__(self):
    """
    Returns a string representing the process
    described by the tree (see treeToString).
    """
    return self.treeToString()
def __getattr__(self, attr):
If the attribute has not been defined for self
try to fetch it from its nodes
:param attr: Attribute name
:return: Attribute value
# If calling another special method, return default (required for pickling)
if (attr.startswith('__') and attr.endswith('__')) or attr in dir(self):
return self.__getattribute__(attr)
val = [getattr(node, attr) if node is not self.root else None
for node in self.nodes]
return val
except AttributeError:
raise AttributeError("Neither SMS nor nodes have attribute ``%s''" % attr)
def add_node(self, node, nodeIndex=None):
    """
    Adds a node object to the tree. If nodeIndex is None,
    the node index will be automatically assigned
    (0 for an empty tree, otherwise the largest existing index + 1).

    :param node: ParticleNode object
    :param nodeIndex: The index for the ParticleNode. It must not
                      match any other indices already in the tree.

    :raises SModelSError: If nodeIndex is already in the tree.

    :return: The node index for the newly added node
    """

    if nodeIndex is None:
        if not self._successors:
            nodeIndex = 0
        else:
            nodeIndex = max(self._successors) + 1
    elif nodeIndex in self._successors:
        raise SModelSError("Trying to add a node with a nodeIndex already in the tree.")

    self._successors[nodeIndex] = []
    self._nodesMapping[nodeIndex] = node

    return nodeIndex
def add_nodes_from(self, nodes):
    """
    Adds a list of nodes to the Tree, each with an automatically
    assigned index (see add_node).

    :param nodes: List of ParticleNode objects

    :return: A list of node indices for the newly added nodes
    """

    return [self.add_node(node) for node in nodes]
def remove_node(self, nodeIndex):
    """
    Removes a node from the tree if the nodeIndex is in the tree.
    The node is removed as well as its appearance in any edges
    and in the node-level caches (node object, canonical name, final states).

    :param nodeIndex: Node index
    """

    self._successors.pop(nodeIndex, None)
    # Do not keep a stale index -> object mapping for the removed node
    self._nodesMapping.pop(nodeIndex, None)
    self._nodeCanonNames.pop(nodeIndex, None)

    # Remove nodeIndex from the daughter lists of the other nodes
    # (values are rebound in place, so iterating the dict is safe)
    for nodeA, daughtersA in self._successors.items():
        if nodeIndex not in daughtersA:
            continue
        self._successors[nodeA] = [d for d in daughtersA if d != nodeIndex]

    # Remove the edge to its parent and the edges to its daughters
    self._predecessors.pop(nodeIndex, None)
    for nodeA, momA in list(self._predecessors.items()):
        if momA == nodeIndex:
            self._predecessors.pop(nodeA)

    self._finalStates.pop(nodeIndex, None)

    # A removed root must not be returned by the cached root index
    if self._rootIndex == nodeIndex:
        self._rootIndex = None
    # NOTE(review): cached canonical names / final states of the ancestors
    # are not invalidated here; confirm callers recompute them when needed.
def remove_nodes_from(self, nodeIndices):
    """
    Removes a list of nodes from the Tree (see remove_node).

    :param nodeIndices: List of node indices
    """

    for nodeIndex in nodeIndices:
        self.remove_node(nodeIndex)
def add_edge(self, nodeIndexA, nodeIndexB):
    """
    Adds a directed edge to existing nodes in the Tree (nodeA -> nodeB).

    :param nodeIndexA: Index for node A (the parent)
    :param nodeIndexB: Index for node B (the daughter)
    """

    self._successors[nodeIndexA].append(nodeIndexB)
    self._predecessors[nodeIndexB] = nodeIndexA
def add_edges_from(self, edges):
    """
    Adds a list of directed edges to the Tree.

    :param edges: List of tuples containing node indices
                  (e.g. [(nodeIndexA,nodeIndexB),(nodeIndexA,nodeIndexC),...])
    """

    for edge in edges:
        self.add_edge(edge[0], edge[1])
def remove_edge(self, nodeIndexA, nodeIndexB):
    """
    Removes an edge from the tree if the edge
    (nodeIndexA -> nodeIndexB) is in the tree.

    :param nodeIndexA: Index for node A (the parent)
    :param nodeIndexB: Index for node B (the daughter)
    """

    if nodeIndexA in self._successors:
        self._successors[nodeIndexA] = [d for d in self._successors[nodeIndexA]
                                        if d != nodeIndexB]
    # Only drop the parent link if it actually points to nodeIndexA
    if nodeIndexB in self._predecessors:
        if self._predecessors[nodeIndexB] == nodeIndexA:
            self._predecessors.pop(nodeIndexB)
def remove_edges(self, edges):
    """
    Removes edges from the tree if they appear in the tree.

    :param edges: List of tuples containing node indices
                  (e.g. [(nodeIndexA,nodeIndexB),(nodeIndexA,nodeIndexC),...])
    """

    for edge in edges:
        self.remove_edge(edge[0], edge[1])
def clear(self):
    """
    Remove all nodes and edges from the graph (and the cached
    root index, canonical names and final states), but
    keep its canonName attribute.
    """

    self._successors = OrderedDict()
    self._predecessors = {}
    self._nodesMapping = {}
    self._nodeCanonNames = {}
    self._finalStates = {}
    self._rootIndex = None
def indexToNode(self, nodeIndex):
    """
    Returns the node object with index nodeIndex.
    If nodeIndex is a list (tuple) of indices, return the corresponding
    list (tuple) of node objects.

    :param nodeIndex: Integer or list/tuple of integers of
                      node indices.

    :raises SModelSError: If nodeIndex is not an int, list or tuple.

    :return: Node object or list/tuple of Node objects
    """

    if isinstance(nodeIndex, int):
        return self._nodesMapping[nodeIndex]
    elif isinstance(nodeIndex, list):
        return [self._nodesMapping[n] for n in nodeIndex]
    elif isinstance(nodeIndex, tuple):
        return tuple(self._nodesMapping[n] for n in nodeIndex)

    raise SModelSError("Can not convert object of type %s to nodes" % str(type(nodeIndex)))
def daughterIndices(self, nodeIndex, ignoreInclusiveNodes=False):
    """
    Returns the list of node indices corresponding to the
    daughters of nodeIndex.

    Note that without filtering the internal successors list itself
    is returned (not a copy).

    :param nodeIndex: Parent node index
    :param ignoreInclusiveNodes: If True, skips inclusive nodes

    :return: List of daughter node indices
    """

    daughters = self._successors[nodeIndex]
    if ignoreInclusiveNodes:
        daughters = [d for d in daughters
                     if not self.indexToNode(d).isInclusive]

    return daughters
def daughters(self, nodeIndex, ignoreInclusiveNodes=False):
    """
    Returns the list of node objects corresponding to the
    daughters of nodeIndex.

    :param nodeIndex: Parent node index
    :param ignoreInclusiveNodes: If True, it skips inclusive nodes

    :return: List of daughter node objects
    """

    daughterIndices = self.daughterIndices(nodeIndex, ignoreInclusiveNodes)
    return self.indexToNode(daughterIndices)
def parentIndex(self, nodeIndex):
    """
    Returns the node index corresponding to the
    parent of nodeIndex.

    :param nodeIndex: Daughter node index

    :return: Parent node index (raises KeyError if nodeIndex has no parent)
    """

    return self._predecessors[nodeIndex]
def parent(self, nodeIndex):
    """
    Returns the node object corresponding to the
    parent of nodeIndex.

    :param nodeIndex: Daughter node index

    :return: Parent node object
    """

    return self.indexToNode(self.parentIndex(nodeIndex))
@property
def rootIndex(self):
    """
    Returns the index of the root node (primary vertex) of the tree.
    If it has not been defined, compute it (the unique node without
    a parent) and cache it.

    :raises SModelSError: If the tree does not have exactly one root.

    :return: root node index
    """

    if self._rootIndex is None:
        root = [nodeIndex for nodeIndex in self.nodeIndices
                if self.in_degree(nodeIndex) == 0]
        if len(root) != 1:
            raise SModelSError("Malformed Tree, %i root(s) have been found." % len(root))
        self._rootIndex = root[0]

    return self._rootIndex
@property
def root(self):
    """
    Returns the root node (primary vertex) of the tree.
    If the root index has not been defined, it is computed.

    :return: root node
    """

    return self.indexToNode(self.rootIndex)
@property
def nodeIndices(self):
    """
    Returns the list of node indices in the Tree
    (in the insertion order of the successors dictionary).

    :return: List of indices (int)
    """

    return list(self._successors.keys())
@property
def edgeIndices(self):
    """
    Returns the list of edges indices (pairs of integers) in the Tree,
    grouped by parent in node order.

    :return: List of (parent index, daughter index) tuples
    """

    return [(n, d) for n in self.nodeIndices
            for d in self.daughterIndices(n)]
@property
def nodes(self):
    """
    Returns the list of ParticleNode objects in the Tree
    (in the same order as nodeIndices).

    :return: List of ParticleNode objects
    """

    return self.indexToNode(self.nodeIndices)
@property
def edges(self):
    """
    Returns the list of edges (pairs of node objects) in the Tree.

    :return: List of (parent node, daughter node) tuples
    """

    return [tuple(self.indexToNode(edgeTuple)) for edgeTuple in self.edgeIndices]
@property
def canonName(self):
    """
    Returns the canonName (canonical name of the root node).
    If not defined, it will be computed.

    :return: Canonical name (int)
    """

    if self.rootIndex not in self._nodeCanonNames:
        self.computeCanonName()

    return self._nodeCanonNames[self.rootIndex]
[docs] def computeCanonName(self, nodeIndex=None):
Recursively sets the canonName for each node.
Returns the canonical name in integer form.
:param nodeIndex: Node index to set the name for. If None, it will use the root
:return: Integer representing the Tree canonical name
if not self.number_of_nodes():
return None
if nodeIndex is None:
nodeIndex = self.rootIndex
# Set the final state
node = self.indexToNode(nodeIndex)
# If it is inclusive node set its name to an inclusive integer
# and return its name (no need to check the children)
if node.isInclusive or node.inclusiveList:
canonName = InclusiveValue()
self._nodeCanonNames[nodeIndex] = canonName
return canonName
children = self.daughterIndices(nodeIndex)
if not children:
canonName = 10
tp = [self.computeCanonName(n) for n in children]
if any(isinstance(name, InclusiveValue) for name in tp):
canonName = InclusiveValue()
tp = sorted(tp)
tpStr = '1' + "".join(str(c) for c in tp) + '0'
canonName = int(tpStr)
self._nodeCanonNames[nodeIndex] = canonName
return canonName
def nodeCanonName(self, nodeIndex):
    """
    Returns the canon name for the node. If it has not been
    cached yet, the canonical names of the whole tree are computed.

    :param nodeIndex: Index of the node

    :return: Canonical name (int)
    """

    if nodeIndex not in self._nodeCanonNames:
        self.computeCanonName()

    return self._nodeCanonNames[nodeIndex]
def out_degree(self, nodeIndex):
    """
    Computes the number of outgoing edges from the node
    (number of daughters).

    :param nodeIndex: Node index (int)

    :return: Number of outgoing edges (int); 0 if the node is not in the tree
    """

    # Dict membership: O(1) instead of scanning the node index list
    if nodeIndex not in self._successors:
        return 0

    return len(self.daughterIndices(nodeIndex))
def in_degree(self, nodeIndex):
    """
    Computes the number of incoming edges to the node
    (number of parents).

    :param nodeIndex: Node index (int)

    :return: Number of incoming edges (1 or 0)
    """

    if nodeIndex in self._predecessors:
        if self._predecessors[nodeIndex] is not None:
            return 1

    return 0
def number_of_nodes(self):
    """
    Returns the total number of nodes in the Tree.

    :return: Number of nodes (int)
    """

    return len(self.nodeIndices)
def genIndexIterator(self, nodeIndex=None,
                     includeLeaves=False, ignoreInclusiveNodes=False):
    """
    Returns an iterator over the generations (mother and its daughters)
    of node indices starting at nodeIndex using a breadth first search.
    The starting node is always yielded (even if it has no daughters).

    :param nodeIndex: Node index from tree. If None, starts at tree root.
    :param includeLeaves: If True, it will consider the leaves (undecayed nodes)
                          as moms in the iterator (with an empty daughters list)
    :param ignoreInclusiveNodes: If True, it skips inclusive nodes and their descendants.

    :return: Iterator over (mom index, list of daughter indices) pairs.
    """

    if nodeIndex is None:
        nodeIndex = self.rootIndex

    mom = nodeIndex
    # An inclusive starting node yields nothing when inclusive nodes are ignored
    if ignoreInclusiveNodes and self.indexToNode(mom).isInclusive:
        return

    daughters = self.daughterIndices(mom, ignoreInclusiveNodes)
    generation = [(mom, daughters)]
    while generation:
        for pair in generation:
            yield pair
        next_generation = []
        for mom, daughters in generation:
            for new_mom in daughters:
                # Skip daughter indices that are not nodes of the tree
                # (dict lookup instead of an O(n) scan of the node list)
                if new_mom not in self._successors:
                    continue
                new_daughters = self.daughterIndices(new_mom, ignoreInclusiveNodes)
                if new_daughters or includeLeaves:
                    next_generation.append((new_mom, new_daughters))
        generation = next_generation
def dfsIndexIterator(self, nodeIndex=None, ignoreInclusiveNodes=False):
    """
    Iterates over the node indices following a depth-first traversal of the tree
    starting at nodeIndex. Each node is yielded at most once.

    :param nodeIndex: Node index to which start the iterator
                      (the corresponding node is NOT included in the iterator).
                      If None, start at the root and include the root itself.
    :param ignoreInclusiveNodes: If True, it skips inclusive nodes and their descendants.

    :return: Iterator over node indices
    """

    if nodeIndex is None:
        nodeIndex = self.rootIndex
        yield self.rootIndex

    nodes = self.daughterIndices(nodeIndex, ignoreInclusiveNodes)
    visited = set()  # Store visited nodes
    # Loop over nodes
    for node in nodes:
        if node in visited:
            continue
        yield node
        visited.add(node)
        # Stack of iterators over the daughters still to be visited
        daughters = [iter(self.daughterIndices(node, ignoreInclusiveNodes))]
        while daughters:
            last_children = daughters[-1]  # Get the last added children
            try:
                child = next(last_children)
            except StopIteration:
                # All the children have been visited, remove from list
                daughters.pop()
                continue
            if child not in visited:
                yield child
                visited.add(child)
                # Add daughters of child
                daughters.append(iter(self.daughterIndices(child, ignoreInclusiveNodes)))
def bfs_sort(self, numberNodes=False):
    """
    Sort the nodes according to their appearance in a
    breadth first search.

    :param numberNodes: If True, renumber the nodes according to their bfs order

    :return: Dictionary with old node indices as keys
             and new indices as values (the identity map if
             numberNodes is False).
    """

    orderedList = [nodeIndex for nodeIndex, _ in
                   self.genIndexIterator(includeLeaves=True)]
    # Sort according to the bfs order
    self.sortAccordingTo(orderedList)

    if not numberNodes:
        # Define dummy dict
        return {n: n for n in self.nodeIndices}

    # Re-number nodes according to their order. The position lookup is built
    # once (first occurrence wins, as with list.index) to avoid O(n^2).
    position = {}
    for i, n in enumerate(orderedList):
        position.setdefault(n, i)
    indexDict = {n: position[n] for n in self.nodeIndices}
    self.relabelNodeIndices(indexDict)

    return indexDict
def sortAccordingTo(self, indicesList):
    """
    Sort the nodes (and the daughters of each node) according to their
    order in indicesList. Node indices not present in indicesList are
    placed after the listed ones, keeping their current relative order.

    :param indicesList: List of node indices used to sort the nodes.
    """

    indices = self.nodeIndices
    # Position of each index in the sort order (first occurrence wins, as
    # with list.index). Built once instead of an O(n) scan per sort key.
    position = {}
    for i, nodeIndex in enumerate(indicesList):
        position.setdefault(nodeIndex, i)
    # Indices not present in indicesList are put at the end of the list
    nextPosition = len(indicesList)
    for nodeIndex in indices:
        if nodeIndex not in position:
            position[nodeIndex] = nextPosition
            nextPosition += 1

    # Go over sorted indices and create new successors dict
    newSuccessors = OrderedDict()
    for nodeIndex in sorted(indices, key=position.__getitem__):
        # Sort daughters according to the list
        newSuccessors[nodeIndex] = sorted(self.daughterIndices(nodeIndex),
                                          key=position.__getitem__)

    self._successors = newSuccessors
def sort(self, nodeIndex=None, force=False):
    """
    Sort subtree of self generated by nodeIndex.
    If nodeIndex is None, sort the tree and re-number the nodes
    according to the bfs order (after sorting).
    If the self is already tagged as sorted and force = False,
    do nothing.

    :param nodeIndex: Node index
    :param force: If True, will sort even if self is tagged as sorted.

    :return: Dictionary with old node indices as keys
             and new indices as values when the whole tree is sorted;
             None otherwise (already sorted, no canonical name, or a subtree).
    """

    # If tree is already sorted, do nothing
    if getattr(self, '_sorted', False) and not force:
        return None

    if nodeIndex is None:
        cName = self.canonName  # Just to make sure canonName is defined
        if cName is None:
            return None
        nodeIndex = self.rootIndex

    daughters = self.daughterIndices(nodeIndex)
    if daughters:
        # Sort the subtrees first, so they can be compared
        for d in daughters:
            self.sort(d, force=force)
        sorted_daughters = self.sortSubTrees(daughters)
        # Remove nodeIndex -> daughters edges
        self.remove_edges([(nodeIndex, d) for d in daughters])
        # Add edges with the correct ordering:
        self.add_edges_from([(nodeIndex, d) for d in sorted_daughters])

    # Finally, after sorting the subtrees,
    # make sure the nodes are sorted and numbered according
    # to the generations (breadth-first search)
    if nodeIndex == self.rootIndex:
        nodeIndexMap = self.bfs_sort(numberNodes=True)
        # Tag the tree as sorted
        self._sorted = True
        return nodeIndexMap

    return None
def sortSubTrees(self, subtreeList):
    """
    Sorts a list of subtrees of self generated by the nodes
    in subtreeList: first by canonical name, then (within equal
    canonical names) with sortCommonSubTrees.

    :param subtreeList: List of node indices to be considered as roots
                        of the subtrees.

    :return: Sorted list of node indices.
    """

    if len(subtreeList) <= 1:
        return subtreeList

    # First group subtrees by their canonical name:
    nameDict = {}
    for nodeIndex in subtreeList:
        nameDict.setdefault(self.nodeCanonName(nodeIndex), []).append(nodeIndex)

    # Now sort subtrees with common canonical names:
    sorted_trees = []
    for cName in sorted(nameDict.keys()):
        # Within equal canonical names sort daughters by the generated subtrees
        sorted_trees += self.sortCommonSubTrees(nameDict[cName])

    return sorted_trees
def sortCommonSubTrees(self, subtreeList):
    """
    Sorts a list of subtrees of self generated by the nodes
    in subtreeList using a quicksort algorithm (first element as pivot,
    Lomuto partition, comparison via compareSubTrees).
    All the subtrees should have a common topology
    (same canonical name).

    :param subtreeList: List of node indices

    :return: Sorted list of node indices.
    """

    if len(subtreeList) <= 1:
        return subtreeList

    s = subtreeList[:]
    pivot = s[0]
    i = 0
    for j in range(len(s) - 1):
        if self.compareSubTrees(self, s[j + 1], pivot) < 0:
            s[j + 1], s[i + 1] = s[i + 1], s[j + 1]
            i += 1
    # Move the pivot to its final position
    s[0], s[i] = s[i], s[0]
    first_part = self.sortCommonSubTrees(s[:i])
    second_part = self.sortCommonSubTrees(s[i + 1:])

    # The pivot must be kept between the two sorted partitions
    return first_part + [s[i]] + second_part
def compareSubTrees(self, other, n1, n2):
    """
    Compare the subtrees generated by the nodeIndex n1 in self
    and the nodeIndex n2 in other: first by canonical name, then
    by the root nodes, then recursively by the (already sorted) daughters.

    :param other: TheorySMS object (if other=self compare subtrees of the same SMS).
    :param n1: Node index for the root of subtree1
    :param n2: Node index for the root of subtree2

    :return: 0, if subtrees are equal, -1 if subtree1 < subtree2, 1 if subtree1 > subtree2
    """

    # Compare node canonical names
    cName1 = self.nodeCanonName(n1)
    cName2 = other.nodeCanonName(n2)
    if cName1 != cName2:
        return -1 if cName1 < cName2 else 1

    # Compare nodes
    cmp = self.compareNodes(other, n1, n2)
    if cmp != 0:
        return cmp

    daughters1 = self.daughterIndices(n1)
    daughters2 = other.daughterIndices(n2)
    # If nodes are leaves, return 0
    if len(daughters1) == len(daughters2) == 0:
        return 0

    # Check if the daughters from n2 match the ones from n1:
    # (the daughters should be sorted at this point)
    # NOTE(review): equal canonical names are assumed to imply the same number
    # of daughters; with inclusive canonical names this may not hold — confirm.
    for i1, d1 in enumerate(daughters1):
        d2 = daughters2[i1]
        cmp = self.compareSubTrees(other, d1, d2)
        if cmp != 0:
            return cmp

    return 0
def compareNodes(self, other, nodeIndex1, nodeIndex2):
    """
    Convenience function for defining how nodes are compared
    within the SMS (delegates to the node's compareTo method).

    :param other: TheorySMS object (if other=self compare subtrees of the same SMS).
    :param nodeIndex1: Index of first node (in self)
    :param nodeIndex2: Index of second node (in other)

    :return: 1 if node1 > node2, -1 if node1 < node2, 0 if node1 == node2.
    """

    node1 = self.indexToNode(nodeIndex1)
    node2 = other.indexToNode(nodeIndex2)

    # Directly use node comparison:
    return node1.compareTo(node2)
def copyTreeFrom(self, other, nodesObjDict):
    """
    Replaces the tree structure (nodes, edges, indices, cached canonical
    names and final states) by the structure in other. Uses the nodesObjDict
    to set the new node mapping (nodeIndex > nodeObj).
    The daughter and final-state lists are copied, not shared.

    :param other: SMS object
    :param nodesObjDict: Dictionary where keys are node indices (from other)
                         and values are node objects.
    """

    self._successors = OrderedDict((n, dList[:]) for n, dList in other._successors.items())
    self._predecessors = dict(other._predecessors)
    self._canonName = other._canonName
    self._rootIndex = other._rootIndex
    self._nodesMapping = dict(nodesObjDict)
    self._nodeCanonNames = dict(other._nodeCanonNames)
    # Copy the final states from other (previously copied from self,
    # which kept a stale cache that did not match the new structure)
    self._finalStates = {nodeIndex: pList[:] for nodeIndex, pList
                         in other._finalStates.items()}
def treeToString(self, removeIndicesFrom='stable'):
    """
    Convert the tree to a process string, e.g. (with the default option)
    '(PV > gluino(1),squark(2)), (gluino(1) > MET,jet,jet), (squark(2) > HSCP,u)'

    Node indices can be removed from specific particles using the removeIndicesFrom option.
    Allowed values are:

      None -> keep all indices
      'SM' -> remove indices from SM particles
      'stable' -> remove indices from stable (undecayed) particles
      'all' -> remove indices from all particles

    The default is to remove indices from stable particles.
    The index of the root (primary vertex) is always omitted.

    :param removeIndicesFrom: If defined, will remove indices from particles according to their properties.

    :raises SModelSError: If removeIndicesFrom is not one of the allowed values.

    :return: String describing the process
    """

    if removeIndicesFrom not in [None, 'SM', 'stable', 'all']:
        raise SModelSError("removeIndicesFrom = %s value not accepted for treeToString" % removeIndicesFrom)

    rmFrom = removeIndicesFrom

    def label(node, index):
        # Node string, with '(index)' appended unless rmFrom removes it
        if (rmFrom == 'all'
                or (rmFrom == 'SM' and node.isSM)
                or (rmFrom == 'stable' and self.out_degree(index) == 0)):
            return '%s' % (node,)
        return '%s(%i)' % (node, index)

    rootIndex = self.rootIndex
    vertices = []
    for momIndex, daughterIndices in self.genIndexIterator(rootIndex):
        # Convert from indices to node objects
        mom = self.indexToNode(momIndex)
        daughters = self.indexToNode(daughterIndices)
        if momIndex == rootIndex:  # Always remove from PV
            momStr = '%s' % (mom,)
        else:
            momStr = label(mom, momIndex)
        daughtersStr = ','.join(label(d, dIndex)
                                for d, dIndex in zip(daughters, daughterIndices))
        if daughtersStr:
            vertices.append('(%s > %s)' % (momStr, daughtersStr))
        else:
            vertices.append('(%s >)' % momStr)

    return ', '.join(vertices)
[docs] def treeToBrackets(self):
"""
Convert the Tree to a nested list with the Z2 even
final states. The Z2 odd final states (e.g. 'MET', 'HSCP') are
not included. The Tree must be Z2-preserving and represent the
pair production cascade decay of two Z2-odd particles.

:raises SModelSError: If the root does not have exactly two daughters,
                      or if a Z2-violating decay is found.

:return: Tuple (branchList, finalState, intermediateState), where branchList
         is the nested list with the strings for the Z2-even final states
         (e.g. [[['e-','mu'],['L']],[['jet']]]) and finalState holds the
         strings of the undecayed non-SM daughters collected in fstates.
"""
branches = self.daughters(self.rootIndex)
branchIndices = self.daughterIndices(self.rootIndex)
finalState = []
intermediateState = []
branchList = []
if len(branches) != 2:
raise SModelSError("Can not convert tree to bracket with %i branches" % len(branches))
for ib, b in enumerate(branches):
bIndex = branchIndices[ib]
# NOTE(review): nothing in this copy ever appends to branchList or
# intermediateState; statements appear to be missing (e.g. per-branch
# list initialisation) — confirm against upstream before use.
# Deal separately with the case where the primary mother is stable:
if self.out_degree(bIndex) == 0:
if not b.isSM:
# NOTE(review): as written, a stable non-SM (Z2-odd) primary would
# raise below; the handling of that case (and an `else:` before the
# raise) appears to be missing — confirm.
raise SModelSError("Can not convert tree with Z2-violating decays to bracket")
for momIndex, daughterIndices in self.genIndexIterator(bIndex):
# Convert from indices to node objects
mom = self.indexToNode(momIndex)
daughters = self.indexToNode(daughterIndices)
# SM (Z2-even) daughters of this vertex
vertexList = [str(d) for d in daughters if d.isSM]
fstates = []
for idaughter, daughter in enumerate(daughters):
# NOTE(review): the bodies of the two `if` statements below are not
# visible (presumably skipping SM daughters, recording decaying ones
# and collecting stable ones in fstates) — confirm.
if daughter.isSM:
if self.out_degree(daughterIndices[idaughter]) != 0:
# NOTE(review): body missing (presumably appends vertexList to branchList).
if daughters:
# A Z2-preserving vertex has exactly one non-SM daughter and at most
# one stable non-SM final state
if len(vertexList) != len(daughters) - 1 or len(fstates) > 1:
raise SModelSError("Can not convert tree with Z2-violating decays to bracket: \n %s" % self.treeToString())
finalState += fstates
return branchList, finalState, intermediateState
def getFinalStates(self, nodeIndex=None):
    """
    Get the list of nodes which have not decayed (appear at the top of the tree).
    If nodeIndex is defined, get the final states generated by the cascade decay
    of that node. The result is cached in self._finalStates (the cached
    list itself is returned).

    :param nodeIndex: Node index for which to get the final states.
                      If None, use the root.

    :returns: list of node indices
    """

    if nodeIndex is None:
        nodeIndex = self.rootIndex

    if nodeIndex in self._finalStates:
        return self._finalStates[nodeIndex]

    if self.out_degree(nodeIndex) == 0:
        # For leaves, the final state is themselves:
        self._finalStates[nodeIndex] = [nodeIndex]
    else:
        finalStates = []
        for d in self.daughterIndices(nodeIndex):
            finalStates += self.getFinalStates(d)
        self._finalStates[nodeIndex] = finalStates

    return self._finalStates[nodeIndex]
def relabelNodeIndices(self, nodeIndexDict):
    """
    Relabel node indices according to nodeIndexDict.
    Every node index in the tree must appear as a key of nodeIndexDict.
    The node objects, cached canonical names and cached final states
    are carried over to the new indices.

    :param nodeIndexDict: Dictionary with current node indices as keys
                          and new indices as values

    :raises SModelSError: If a node index of the tree is missing from nodeIndexDict
    """

    if any(nodeIndex not in nodeIndexDict for nodeIndex in self.nodeIndices):
        raise SModelSError("Dictionary for relabelling nodes must contain all node indices")

    # Resolve the root with the current structure: once the dicts below are
    # replaced, a (re)computed root index would already be a *new* index.
    oldRootIndex = self.rootIndex

    newMapping = {}
    newSuccessors = OrderedDict()
    newPredecessors = {}
    newCanonNames = {}
    newFinalStates = {}
    for oldIndex, newIndex in nodeIndexDict.items():
        # Update daughter indices
        newDaughters = [nodeIndexDict.get(d, d) for d in self.daughterIndices(oldIndex)]
        newPredecessors.update({d: newIndex for d in newDaughters})
        # Add entry to newSuccessors dict:
        newSuccessors[newIndex] = newDaughters
        # Add entry to newMapping dict
        newMapping[newIndex] = self.indexToNode(oldIndex)
        # Add entry to canonNames dict
        if oldIndex in self._nodeCanonNames:
            newCanonNames[newIndex] = self._nodeCanonNames[oldIndex]
        if oldIndex in self._finalStates:
            # The cached final states are node indices, so relabel them too
            newFinalStates[newIndex] = [nodeIndexDict.get(n, n)
                                        for n in self._finalStates[oldIndex]]

    # Update dicts:
    self._successors = newSuccessors
    self._predecessors = newPredecessors
    self._nodesMapping = newMapping
    self._nodeCanonNames = newCanonNames
    self._finalStates = newFinalStates
    self._rootIndex = nodeIndexDict[oldRootIndex]
def updateNodeObjects(self, nodeObjectDict):
    """
    Update the node index -> node object mapping.
    Only affects the indices appearing in nodeObjectDict.

    :param nodeObjectDict: Dictionary with current node indices as keys
                           and new node objects as values
    """

    for nodeIndex, newObj in nodeObjectDict.items():
        self._nodesMapping[nodeIndex] = newObj
def checkConsistency(self):
    """
    Make sure the tree has the correct topology (directed rooted tree).
    Trees with zero or one node are accepted as they are.

    For larger trees it checks that:
      - the root has no parent and at least one daughter;
      - every other node has exactly one parent;
      - there are exactly (number of nodes - 1) edges (so no node
        appears as the daughter of two different nodes);
      - every node can be reached from the root.

    :raises SModelSError: If the structure is not a tree.
    """

    malformedTree = False
    nodeIndices = self.nodeIndices
    nNodes = len(nodeIndices)
    if nNodes > 1:
        rootIndex = self.rootIndex
        # Check if root has no parents and at least one daughter
        if self.in_degree(rootIndex) != 0 or self.out_degree(rootIndex) == 0:
            malformedTree = True
        # Check if all nodes (except root) have a unique parent
        elif any(self.in_degree(nodeIndex) != 1 for nodeIndex in nodeIndices
                 if nodeIndex != rootIndex):
            malformedTree = True
        # A shared daughter keeps only one predecessor entry, so it is
        # caught by the edge count rather than by in_degree
        elif len(self.edgeIndices) != nNodes - 1:
            malformedTree = True
        # Check if all nodes can be reached from the root node
        elif len(list(self.dfsIndexIterator())) != nNodes:
            malformedTree = True

    if malformedTree:
        raise SModelSError("Graph created with malformed structure (not a tree).")
def switchBranches(self):
    """
    If the SMS has a two branch structure (PV > X,Y), return
    a new SMS with its branches switched (PV > Y,X).
    Otherwise return None.

    :return: A new SMS object with the branches switched or None.
    """

    if len(self.daughterIndices(self.rootIndex)) != 2:
        return None

    smsNew = self.copy()
    # Resolve the root before detaching the branches (afterwards their
    # daughters temporarily have no parent, so it could not be recomputed)
    rootIndex = smsNew.rootIndex
    branchIndices = smsNew.daughterIndices(rootIndex)[:]
    nodes = {}
    edges = {}
    for bIndex in branchIndices:
        nodes[bIndex] = smsNew.indexToNode(bIndex)
        edges[bIndex] = [(bIndex, d) for d in smsNew.daughterIndices(bIndex)]
        # Detach the branch node (its descendants stay in the tree)
        smsNew.remove_node(bIndex)

    # Re-attach the branches in reversed order
    for bIndex in branchIndices[::-1]:
        smsNew.add_node(nodes[bIndex], bIndex)
        smsNew.add_edge(rootIndex, bIndex)
        smsNew.add_edges_from(edges[bIndex])

    # The cached root final states follow the old daughter order
    smsNew._finalStates.pop(rootIndex, None)
    # NOTE(review): both loop bodies were reconstructed (they are missing
    # from this copy of the file); confirm against upstream.

    return smsNew
[docs] def draw(self, particleColor='steelblue2',
attrUnit=None, filename=None, view=True,
graph_kwargs={'layout' : 'dot', 'ranksep' : '0.3', 'rankdir' : "LR"},
nodes_kwargs={'style' : 'filled', 'fontsize' : '10', 'color' : 'black','shape' : 'circle','margin' : '0'},
edges_kwargs={'arrowhead' : 'vee', 'arrowsize' : '0.7', 'color' : 'grey53'}):
Draws Tree using matplotlib.
:param particleColor: color for particle nodes
:param smColor: color used for particles which have the isSM attribute set to True
:param pvColor: color for primary vertex
:param fontsize: Font size for labels
:param labelAttr: attribute to be used as label. If None, will use the string representation of the node object. It can also be a dictionary with node indices as keys and the label strings as values.
:param attrUnit: Unum object with the unit to be removed from label attribute(if applicable)
:param filename: Filename to save drawing to.
:param view: open a viewer after plotting
:param maxLabelSize: Maximum size for the label string for the node. If the label is larger, it will be truncated.
If None/False/0, it will keep the full label.
:param usePVimage: Path to a image file (png, bmp or jpeg) to be used instead of the primary vertex (PV) node.
:param graph_kwargs: Dictionary with graph attributes to be used.
:param nodes_kwargs: Dictionary with nodes attributes to be used.
:param edges_kwargs: Dictionary with nodes attributes to be used.
:return: Display a GraphViz Digraph object, if view is true (and save it to file if filename is defined)
import graphviz
except ImportError:
raise SModelSError("For drawing SMS objects, please install graphviz")
nodesAndIndices = zip(self.nodes,self.nodeIndices)
if labelAttr is None:
labels = {nodeIndex: "" for _,nodeIndex in nodesAndIndices}
elif isinstance(labelAttr,dict):
labels = {k : v for k,v in labelAttr.items()}
elif labelAttr == 'label':
labels = {nodeIndex: str(n) for n,nodeIndex in nodesAndIndices}
elif attrUnit is not None:
labels = {nodeIndex: str(getattr(n, labelAttr).asNumber(attrUnit))
if (hasattr(n, labelAttr) and getattr(n, labelAttr) is not None)
else str(n) for n,nodeIndex in nodesAndIndices}
elif labelAttr == 'node':
labels = {nodeIndex: str(nodeIndex) for _,nodeIndex in nodesAndIndices}
elif labelAttr == 'canonName':
labels = {nodeIndex: str(self.nodeCanonName(nodeIndex))
for _,nodeIndex in nodesAndIndices}
labels = {nodeIndex: str(getattr(n, labelAttr)) if hasattr(n, labelAttr)
else str(n) for n,nodeIndex in nodesAndIndices}
for key in labels:
if labels[key] == 'anyOdd':
labels[key] = 'BSM'
node_color = {}
for n in self.nodeIndices:
node = self.indexToNode(n)
if n == self.rootIndex:
node_color[n] = pvColor
elif hasattr(node, 'isSM') and node.isSM:
node_color[n] = smColor
node_color[n] = particleColor
# Truncate labels if needed:
if maxLabelSize:
for key,val in labels.items():
if len(val) > maxLabelSize:
labels[key] = val[:maxLabelSize]+'...'
dot = graphviz.Digraph()
for key,val in graph_kwargs.items():
dot.attr(**{key : str(val)})
for nodeIndex in self.nodeIndices:
if labels[nodeIndex] == 'PV' and usePVimage:
fontsize='16', label="",color='white',
width='0.6', height='1.2', fixedsize='true',
nodeAttrs = {k : v for k,v in nodes_kwargs.items()}
if 'label' not in nodeAttrs:
nodeAttrs['label'] = labels[nodeIndex]
if 'fillcolor' not in nodeAttrs:
nodeAttrs['fillcolor'] = node_color[nodeIndex]
dot.node(str(nodeIndex), **nodeAttrs)
for edgeIndex in self.edgeIndices:
# If filename is defined, save image
if filename is not None:
import os
filename = os.path.abspath(filename)
# dot.format = extension[1:]
dot.render(outfile=filename, view=view, cleanup=True)
# Try to display (for various circumstances)
if view:
display(dot) # for notebooks
except NameError:
import os
fname = filename
if fname != None:
fname, _ = os.path.splitext(filename)
dot.view(filename=fname) # for terminals
except (RuntimeError, graphviz.ExecutableNotFound,\
graphviz.CalledProcessError) as e: