Generating decision trees

The crux of the problem is to transform a case statement into a decision tree. A case statement has a value, a sequence of arms, and a trailer. Each arm has a pattern and code to be executed. When the case statement is executed, it chooses the first arm whose pattern matches the value, executes the corresponding code, and then executes the trailer. I generate a decision tree to do the job. Each internal node of the decision tree tests a field of the value. It chooses an edge (child) based on that field's value and continues testing fields until it reaches a leaf, at which point it executes the code associated with that leaf.

The goal of tree generation is not to generate just any tree, but the tree with the fewest nodes. This problem is NP-complete, so I apply a few heuristics. The results, at least for the machine descriptions I use, seem to be as good as what I would come up with by hand.

The arms of the case statement carry some extra information. The file and line number help with error messages and make it possible to generate #line directives that identify the source of the code. The original field records the arm from which the current arm is derived, and is useful for many of the heuristics.

<*>= [D->]
record caserec(arms,valcode,trailer)
                # case arms, code to compute value, trailing code
record arm(file, line, pattern, code, original)
                # pattern and code are the content
                # line, file, original(pattern) are used for error reporting
Defines arm, caserec (links are to index).

Each node of the decision tree is associated with a particular case statement. An internal node has children and a field, which records the field chosen for testing at that node. Each edge to a child records the interval of field values that leads to that child. A leaf node has a name, which records the name of the pattern known to match at that leaf.

<*>+= [<-D->]
record node(cs, children, field, name)
        # case statement, list of edges to children, field chosen, pattern name
        #       (name field used to support name operator, assigned only to leaves)
record edge(node, lo, hi)
        # node pointed to and lo and hi interval of field for this edge
Defines edge, node (links are to index).
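
To make the shape of the tree concrete, here is a small Python sketch (not part of the Icon source) of how such a tree could be walked at run time. The class names Node and Edge mirror the records above, but the walk function, the dictionary representation of a word, and the example tree are assumptions made for illustration. Intervals are treated as half-open, lo <= value < hi.

```python
from dataclasses import dataclass, field as dc_field
from typing import List, Optional


@dataclass
class Node:
    field: Optional[str] = None                      # field tested (internal nodes)
    children: List["Edge"] = dc_field(default_factory=list)
    name: Optional[str] = None                       # pattern name (leaves only)


@dataclass
class Edge:
    node: Node
    lo: int   # inclusive lower bound of the field value
    hi: int   # exclusive upper bound of the field value


def walk(root, word):
    """Follow the edge whose interval contains each tested field; return the leaf."""
    n = root
    while n.children:
        value = word[n.field]
        for e in n.children:
            if e.lo <= value < e.hi:
                n = e.node
                break
        else:
            raise ValueError(f"no edge covers {n.field}={value}")
    return n


# Hypothetical two-level tree: test op first, then rd when op = 1.
tree = Node(field="op", children=[
    Edge(Node(name="add"), 0, 1),
    Edge(Node(field="rd", children=[
        Edge(Node(name="nop"), 0, 1),
        Edge(Node(name="mov"), 1, 32),
    ]), 1, 2),
    Edge(Node(name="-NOMATCH-"), 2, 4),
])
```

In this sketch a leaf is simply a node with no children, which matches the way the checking code later distinguishes leaves from internal nodes.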

To create a decision tree, I begin with a node containing the full, original case statement. I then use a "work queue" approach to check each node and see if it needs to be split. If no pattern can match at the node, or if the first pattern always matches (with a unique name), no further splitting needs to be done, and I assign a name to the leaf: "-NOMATCH-" when no arm remains, and otherwise the name of the matching pattern ("-unnamed-" if it has none). [If the arm's code never uses the name, I leave the name null, because that makes it easier to combine nodes in the dagging phase.] Otherwise, I split the node.

<*>+= [<-D->]
procedure needs_splitting(n)
    if *n.cs.arms = 0 then fail
    p := n.cs.arms[1].pattern
    name := \p.disjuncts[1].name | p.name
    every d := !p.disjuncts do {
        dname := \d.name | p.name   # this disjunct's name; don't clobber parameter n
        if dname ~=== name then return   # different names, needs splitting
        else if *d.constraints = 0 then fail    # always matches, needn't split
    }
    return                      # pattern doesn't always match -> split
end

procedure tree(cs)
    static heuristics
    initial heuristics := [leafarms, childarms, nomatch, childdisjuncts, branchfactor]

    root := node(cs, [])        # empty child list, so *n.children is safe if root stays a leaf
    work := [edge(root)]        # work queue of edges (nodes) to be expanded
    while n := get(work).node do
        if needs_splitting(n) then {
            <split node n and add children to work queue>
        } else {
            if *n.cs.arms = 0 then 
                n.name := "-NOMATCH-"
            else if n.cs.arms[1].code ? find_id("name") then {
                p := n.cs.arms[1].pattern
                n.name := \p.disjuncts[1].name | \p.name | "-unnamed-"
            } else 
                n.name := &null
            if \mapnames then n.name := map(\n.name)
        }
    return root
end
Defines needs_splitting, tree (links are to index).

Splitting a node involves choosing a field, finding out which intervals of values of that field are interesting, and creating a child node for each such interval of values. The patterns in the case statement of the child node reflect the knowledge of the value interval of the tested field.

I choose the field by tentatively splitting the node on each field mentioned in the case statement. I then compute some heuristic functions of the children from each splitting and use the best-scoring field.

Some debugging information may be written to hdebug or sdebug.

<split node n and add children to work queue>= (<-U)
fields := mentions(n.cs)
*fields > 0 | impossible("internal node mentions no fields")
candidates := table()
every f := !fields do
    candidates[f] := split(n, f)
<if debugging, split all and report>
*fields > 1 & write(\hdebug, "Choosing one of ", patimage(fields))
every h := !heuristics do {
    if *fields = 1 then break
    fields := findmaxima(h, candidates, fields)
    write(\hdebug, image(h), " chose ", patimage(fields))
}
*fields > 0 | impossible("no fields")
*fields = 1 | write(\hdebug, "tie among fields ", patimage(fields), " near ",
                      image(n.cs.arms[1].original.file), ", line ",
                      n.cs.arms[1].original.line)
work |||:= n.children := candidates[n.field := ?fields]

<if debugging, split all and report>= (<-U)
if \tryall & \hdebug & *fields > 1 then     {
  write(\hdebug, repl("=",10), " Splitting ", repl("=", 10))
  every findmaxima(!heuristics, candidates, fields) do write(\hdebug)
  write(\hdebug, repl("=", 30), "\n")
}

To split a node, I look at each interval of values that might be interesting. I apply that interval to the case statement, and if there can be any match, I create and add a new child node.

<*>+= [<-D->]
procedure split(n, f)
    local vals,v,d,val,c,p,j,i,newd,cst,child,newp

    patterns := []
    children := []
    every put(patterns, (!n.cs.arms).pattern)
    r := intervals(patterns, f)
    <if debugging, write about splitting this node>

    every i := 1 to *r - 1 do
        put(children, edge(node(apply(n.cs, f, r[i], r[i+1]),[]), r[i], r[i+1]))

    write(\sdebug, "Done splitting.\n")
    return children
end    
Defines split (links are to index).

<if debugging, write about splitting this node>= (<-U)
writes(\sdebug, "Splitting ")
outpattern(\sdebug, patterns[1])
every i := 2 to *patterns do { writes(\sdebug, " | "); outpattern(\sdebug, patterns[i])}
write(\sdebug, " on ", f.name)

So, what is the new case statement that results from applying lo <= f < hi to cs? For each arm, I match the pattern against the interval. If the match succeeds, I create a new arm for the new case statement, containing the reduced pattern. If it fails, the arm cannot match in this interval and is dropped.

<*>+= [<-D->]
procedure apply(cs, f, lo, hi)
    result := copy(cs)
    result.arms := []
    write(\sdebug, "    Applying ", stringininterval(f.name, lo, hi))
    every a := !cs.arms do
        put(result.arms, 
            arm(a.file, a.line, pmatch(a.pattern, f, lo, hi), a.code, a.original))
    if alwaysmatches(result.arms[1].pattern) then
        result.arms := [result.arms[1]]
    return result
end

# if lo <= f < hi and p matches, return the new p

procedure pmatch(p, f, lo, hi)
    result := pattern([], p.name)
    every d := !p.disjuncts do 
        if c := !d.constraints & c.field === f then     # disjunct mentions f
            if c.lo <= lo & hi <= c.hi then {           # this constraint is matched
                newd := disjunct([], d.name)
                every c := !d.constraints & c.field ~=== f do 
                    put(newd.constraints, c)
                put(result.disjuncts, newd)
            } else
                c.hi <= lo | c.lo >= hi | impossible("bad intervals")
        else                                             # disjunct does not mention f
            put(result.disjuncts, d)
    <if debugging, write about results of pmatch>
    if *result.disjuncts > 0 then return result
end

Defines apply, pmatch (links are to index).
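
The reduction performed by pmatch can be illustrated with a short Python sketch. This is not the Icon code above: it assumes a pattern is a list of disjuncts, a disjunct is a list of (field, lo, hi) tuples, and names are left out. It also omits the "bad intervals" sanity check, relying on the cut points to guarantee that each interval lies wholly inside or wholly outside every constraint.

```python
def pmatch(disjuncts, f, lo, hi):
    """Reduce a pattern given that lo <= f < hi; return None if nothing can match."""
    result = []
    for d in disjuncts:
        on_f = [c for c in d if c[0] == f]
        if not on_f:
            # Disjunct does not mention f: it is unaffected by the test.
            result.append(d)
        elif all(c_lo <= lo and hi <= c_hi for _, c_lo, c_hi in on_f):
            # Constraint on f is satisfied throughout the interval: drop it.
            result.append([c for c in d if c[0] != f])
        # Otherwise the interval lies outside the constraint; the disjunct is dropped.
    return result or None


# Hypothetical pattern: (op in [0,2) and rd in [1,32)) or (op in [2,4)).
p = [[("op", 0, 2), ("rd", 1, 32)], [("op", 2, 4)]]
```

A disjunct reduced to an empty constraint list always matches, which is the situation that needs_splitting tests for with *d.constraints = 0.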

<if debugging, write about results of pmatch>= (<-U)
if *result.disjuncts > 0 then writes(\sdebug, "        ===> ") & outpattern(\sdebug, p)
# else writes(\sdebug, "             ") & outpattern(\sdebug, p)

if *result.disjuncts > 0 then write(\sdebug, " matches") 
# else write(\sdebug, " does not match")

Tree-minimization heuristics

First, the boilerplate that takes a heuristic h, candidate splittings, and a set of fields, and returns the set of fields with the largest score on h.

<*>+= [<-D->]
procedure findmaxima(h, candidates, fields)
    local max
    S := []
    every f := !fields do {
        score := h(candidates[f], f)
        write(\hdebug,"Field ", f.name, " scores ", score, " on ", image(h))
        /max := score - 1
        if score > max then {
            max := score
            S := [f]
        } else if score = max then
            put(S, f)
    }
    return set(S)
end
Defines findmaxima (links are to index).
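
The way the heuristics are applied in sequence, each one breaking ties left by the previous ones, can be sketched in Python as follows. This is an illustration only: here a candidate splitting is reduced to a list of arm counts, one per child, and the function choose and the two toy heuristics are assumptions that merely imitate childarms and branchfactor.

```python
def find_maxima(h, candidates, fields):
    """Return the subset of fields whose candidate splitting scores highest on h."""
    scores = {f: h(candidates[f], f) for f in fields}
    best = max(scores.values())
    return {f for f in fields if scores[f] == best}


def choose(heuristics, candidates):
    """Filter the fields through each heuristic in turn until one field remains."""
    fields = set(candidates)
    for h in heuristics:
        if len(fields) == 1:
            break
        fields = find_maxima(h, candidates, fields)
    return fields


def childarms(children, f):
    return -sum(children)        # fewer arms in children scores higher


def branchfactor(children, f):
    return -len(children)        # fewer children scores higher


# Hypothetical candidates: field name -> arm count of each child.
candidates = {"op": [1, 1, 2], "rd": [2, 2], "imm": [3, 3]}
```

With these numbers, childarms leaves op and rd tied, and branchfactor then prefers rd. If several fields survive every heuristic, the Icon code picks one at random with ?fields.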

Here's a big pile of heuristics. I'm not sure I've ever needed more than the first two, but they're amusing and easy enough to write.

<*>+= [<-D->]
# leafarms: prefer candidate with most arms that appear at leaf
#           nodes.  Each original arm counted only once.
#           Not matching is also counted as an arm.

procedure leafarms(children, f) 
    arms := set()
    every n := (!children).node & *n.cs.arms > 0 do
       if not needs_splitting(n) then 
           insert(arms, n.cs.arms[1].original)
    return *arms + if *(!children).node.cs.arms = 0 then 1 else 0
end

# childarms: prefer the candidate with the fewest arms in children

procedure childarms(children, f)
    sum := 0
    every sum -:= *(!children).node.cs.arms
    return sum
end

# nomatch: if tied on leafarms and childarms, take candidate
#          with real leaf in preference to nomatch leaf

procedure nomatch(children, f)
    return if *(!children).node.cs.arms = 0 then -1 else 0
end

# childdisjuncts: prefer the candidate with the fewest disjuncts in children

procedure childdisjuncts(children, f)
    sum := 0
    every sum -:= *(!(!children).node.cs.arms).pattern.disjuncts
    return sum
end

# branchfactor:  prefer the candidate with the fewest children

procedure branchfactor(children, f)
    return - *children
end
Defines branchfactor, childarms, childdisjuncts, leafarms, nomatch (links are to index).

Utility functions

<*>+= [<-D->]
# If f is to be used to split patterns, what intervals need to be considered?

procedure intervals(patterns, f)
    cuts := set([0, 2^(f.hi - f.lo)])
    every p := !patterns & d := !p.disjuncts & c := !d.constraints & c.field === f do
        every insert(cuts, c.lo | c.hi)
    return sort(cuts)
end

# what fields are mentioned in a case statement?

procedure mentions(cs) 
    result := set()
    every a := !cs.arms & d := !a.pattern.disjuncts & c := !d.constraints do
       insert(result, c.field)
    return result
end

# find_id: tab to and past identifier id, returning the position where it starts
# does not treat quotes or comment brackets specially

procedure find_id(id)
    static notlnum
    initial notlnum := ~ (&letters ++ &digits ++ '_')
    tab(p := find(id)) & p = 1 | (move(-1) & any(notlnum) & move(1)) &
               =id & pos(0) | any(notlnum) & suspend p
end
Defines find_id, intervals, mentions (links are to index).
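
As an illustration of how the cut points become child intervals, here is a Python sketch of the same idea. It assumes the same tuple representation of constraints, (field, lo, hi), and takes the field width in bits as an explicit parameter rather than computing it from the field record; the helper pairs is hypothetical.

```python
def intervals(patterns, f, width):
    """Sorted cut points for field f: the field's bounds plus every constraint bound."""
    cuts = {0, 2 ** width}
    for p in patterns:
        for d in p:
            for field, lo, hi in d:
                if field == f:
                    cuts.update((lo, hi))
    return sorted(cuts)


def pairs(cuts):
    """Adjacent cut points form the half-open intervals [lo, hi) to try."""
    return list(zip(cuts, cuts[1:]))


# Hypothetical single pattern constraining a 3-bit field op to [2, 3).
example = [[[("op", 2, 3)]]]
```

Each adjacent pair of cut points corresponds to one iteration of the loop in split, which is why that loop runs from 1 to *r - 1.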

Tree checking

Once the tree is generated, it's useful to check it for redundant arms and for arms that never match. These checks help users catch mistakes in their specifications. Note that I must check the "original" arms, not the reduced arms stored at each node; that's why they're there.

<*>+= [<-D]
procedure checktree(n)
    originals := set()
    every insert(originals, (!n.cs.arms).original)
    deletematching(n, originals)
    every a := !originals do 
        warning("No word matches pattern at ", image(a.file), ", line ", a.line)
    if hasnomatch(n) then
        warning("Case statement at ", image(n.cs.arms[1].file), ", line ",
                n.cs.arms[1].line - 1, " doesn't cover all cases")
    return n
end

procedure deletematching(n, originals)
    if *originals = 0 then return
    else if *n.children > 0 then every deletematching((!n.children).node, originals)
    else every delete(originals, (!n.cs.arms).original)
end

procedure hasnomatch(n)
    if *n.children > 0 then return hasnomatch((!n.children).node)
    else if *n.cs.arms = 0 then return  # found it
end
Defines checktree, deletematching, hasnomatch (links are to index).
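
The two checks can be sketched in Python as a single pass over the leaves. This is an illustration under simplifying assumptions: a node is a dictionary with keys "arms" and "children", and each arm is represented directly by an identifier for its original arm.

```python
def leaves(n):
    """Yield every leaf below n (a leaf is a node with no children)."""
    if n["children"]:
        for c in n["children"]:
            yield from leaves(c)
    else:
        yield n


def check(root, originals):
    """Return (original arms that reach no leaf, whether some leaf matches nothing)."""
    reached = set()
    for leaf in leaves(root):
        reached.update(leaf["arms"])
    unmatched = [a for a in originals if a not in reached]
    incomplete = any(not leaf["arms"] for leaf in leaves(root))
    return unmatched, incomplete


# Hypothetical tree: arm a2 never reaches a leaf and one leaf has no arms.
root = {"arms": ["a1", "a2", "a3"], "children": [
    {"arms": ["a1"], "children": []},
    {"arms": ["a3"], "children": []},
    {"arms": [], "children": []},
]}
```

Unlike deletematching, the sketch does not stop early once every original arm has been accounted for; that shortcut is an optimization and does not change the result.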
