Micro C, Part 2: Semantic Analysis
llvm, haskell
Lastmod: 2020-04-17

Having written the parser for our C dialect in part 1, we now proceed to the semantic analysis pass. Here, our goal is to catch as many errors as possible before going on to code generation. In particular, we will be somewhat stricter than real C compilers and disallow implicit conversions because they are both annoying to implement and lead to subtle bugs. Don't worry, though, we can still do whatever pointer casting shenanigans we want to, just as long as we do them explicitly.

The Semantically checked Abstract Syntax Tree (SAST)

To ensure that our semantic analysis doesn't leave any part of the AST unchecked, we define a new type, the SAST, that captures as much type checking information as possible, essentially duplicating each node of the AST and adding a type to it. In the grand scheme of solutions to the problem of adding information to the AST, this is the most "low-tech," but it works quite well for our purposes.

module Microc.Sast where

import           Microc.Ast
import           Data.Text                      ( Text )

type SExpr = (Type, SExpr')
data SExpr' =
    SLiteral Int
  | SFliteral Double
  | SStrLit Text
  | SCharLit Int
  | SBoolLit Bool
  | SNull
  | SBinop Op SExpr SExpr
  | SUnop Uop SExpr
  | SCall Text [SExpr]
  | SCast Type SExpr
  | LVal LValue
  | SAssign LValue SExpr
  | SAddr LValue
  | SSizeof Type
  | SNoexpr
  deriving (Show, Eq)
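To make the pairing of types and nodes concrete, here's a small sketch of what a checked expression looks like. The types below are heavily trimmed stand-ins for the real Expr and SExpr. In particular, Add and SAdd are invented for this example (the real types use Binop and SBinop), names are plain Strings, and we assume that x was declared as an int.

```haskell
-- Hypothetical, heavily trimmed versions of the AST and SAST types.
data Type = TyInt | TyBool
  deriving (Show, Eq)

data Expr = Literal Int | Id String | Add Expr Expr
  deriving (Show, Eq)

type SExpr = (Type, SExpr')
data SExpr' = SLiteral Int | SId String | SAdd SExpr SExpr
  deriving (Show, Eq)

-- Untyped tree for `x + 1`, as a parser would produce it.
untyped :: Expr
untyped = Add (Id "x") (Literal 1)

-- The same tree after checking, assuming x is an int:
-- every node, including each subexpression, carries its own type.
typed :: SExpr
typed = (TyInt, SAdd (TyInt, SId "x") (TyInt, SLiteral 1))
```

Because every subexpression is tagged, later passes can read off the type of any node with fst instead of recomputing it.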

Operators and types are reused from the AST, but we take this opportunity to define lvalues (values that can appear on the left hand side of an assignment, in our case variables, pointer dereferences, and struct field accesses) separately from the rest of the SExpr type. For struct access, we transform the field name from the textual name into an Int indicating the field's position in the struct.

data LValue = SDeref SExpr | SAccess LValue Int | SId Text
  deriving (Show, Eq)

As for statements, we take the opportunity during semantic analysis to transform for and while loops into the corresponding combinations of if statements and do while loops, since these are easier to compile.

data SStatement =
    SExpr SExpr
  | SBlock [SStatement]
  | SReturn SExpr
  | SIf SExpr SStatement SStatement
  | SDoWhile SExpr SStatement
  deriving (Show, Eq)

Functions and whole programs remain basically unchanged, except that a function body is now a single statement block rather than a list of statements.

data SFunction = SFunction
  { styp  :: Type
  , sname :: Text
  , sformals :: [Bind]
  , slocals :: [Bind]
  , sbody :: SStatement
  }
  deriving (Show, Eq)

type SProgram = ([Struct], [Bind], [SFunction])

(Full source for the SAST here.)

Capturing errors

The semantic checker is essentially a function of type Program -> Either SemantError SProgram. To capture possible errors, we define a big SemantError type and some Pretty instances for it.

module Microc.Semant.Error where

import           Microc.Ast
import           Data.Text                      ( Text )
import           Data.Text.Prettyprint.Doc

type Name = Text
-- Tells us if a binding appears in a function, a struct, or
-- is a global variable.
data BindingLoc = F Function | S Struct | Toplevel 
  deriving Show
data SemantError =
    IllegalBinding Name BindingKind VarKind BindingLoc
  | UndefinedSymbol Name SymbolKind Expr
  | TypeError { expected :: [Type], got :: Type, errorLoc :: Statement }
  | CastError { to :: Type, from :: Type, castLoc :: Statement }
  | ArgError { nExpected :: Int, nGot :: Int, callSite :: Expr }
  | Redeclaration Name
  | NoMain
  | AddressError Expr
  | AssignmentError { lhs :: Expr, rhs :: Expr }
  | AccessError { struct :: Expr, field :: Expr }
  | DeadCode Statement -- ^ For statements in a block following a return
  deriving (Show)

data BindingKind = Duplicate | Void deriving (Show)
data SymbolKind = Var | Func deriving (Show)

data VarKind = Global | Formal | Local | StructField
  deriving (Show, Eq, Ord)

There are many approaches to error handling in Haskell, from using a centralized error type like we do here, to using something like generic-lens or classy error handling, to just returning a plain String.1 I'm not convinced there's a clearly "best" approach to this problem. The approach here certainly has its drawbacks, in that maintaining it is annoying and tedious. For a bigger project, I would hesitate to use a single type for all possible errors, but this one is small enough that it's feasible.

(Full source for Microc.Semant.Error here, with pretty printing code.)

From AST to SAST

Now we're ready to begin writing the semantic checker in earnest.

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
module Microc.Semant
  ( checkProgram
  )
where

import           Microc.Ast
import           Microc.Sast
import           Microc.Semant.Error
-- We'll discuss these modules later
import           Microc.Semant.Analysis
import           Microc.Utils

import qualified Data.Map                      as M
import           Control.Monad.State
import           Control.Monad.Except
import           Data.Maybe                     ( isJust )
import           Data.Text                      ( Text )
import           Data.List                      ( find
                                                , findIndex
                                                )

We'll need to maintain information about what variables and functions have already been declared and have it accessible to the semantic checking functions. Combined with the need to be able to return SemantError at any point during the checking process, we arrive at the following.

type Vars = M.Map (Text, VarKind) Type
type Funcs = M.Map Text Function
type Structs = [Struct]
data Env = Env { vars     :: Vars
               , funcs    :: Funcs
               , structs  :: Structs
               }

type Semant = ExceptT SemantError (State Env)
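To see how this monad stack behaves, here's a minimal, self-contained sketch. Demo, DemoEnv, declare, and runDemo are stand-ins invented for this illustration, and errors are plain Strings. The real checker uses Semant, Env, and SemantError instead.

```haskell
import           Control.Monad                  ( when )
import           Control.Monad.Except           ( ExceptT
                                                , runExceptT
                                                , throwError
                                                )
import           Control.Monad.State            ( State
                                                , evalState
                                                , gets
                                                , modify
                                                )
import qualified Data.Map                      as M

-- Hypothetical, simplified environment: variable names mapped to type names.
type DemoEnv = M.Map String String

-- Same shape as Semant: errors layered on top of a state monad.
type Demo = ExceptT String (State DemoEnv)

-- Record a variable, failing if it was already declared.
declare :: String -> String -> Demo ()
declare name ty = do
  seen <- gets (M.member name)
  when seen $ throwError ("duplicate binding: " ++ name)
  modify (M.insert name ty)

-- Unwrap the stack: ExceptT first, then State, starting from an empty map.
runDemo :: Demo a -> Either String a
runDemo action = evalState (runExceptT action) M.empty
```

With these definitions, runDemo (declare "x" "int" >> declare "y" "float") evaluates to Right (), while runDemo (declare "x" "int" >> declare "x" "bool") evaluates to Left "duplicate binding: x". The first throwError short-circuits everything after it, which is exactly the behavior we want from the checker.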

To check variable bindings, we need to know if they belong to a struct, function, or the top level, make sure that they are not of type void, and make sure that they don't conflict with any already-declared variables.

checkBinds :: VarKind -> BindingLoc -> [Bind] -> Semant [Bind]
checkBinds kind loc binds = do
  forM binds $ \case
    Bind TyVoid name -> throwError $ IllegalBinding name Void kind loc

    Bind ty     name -> do
      vars <- gets vars
      when (M.member (name, kind) vars)
        $ throwError (IllegalBinding name Duplicate kind loc)
      modify $ \env -> env { vars = M.insert (name, kind) ty vars }
      pure $ Bind ty name

Checking struct fields is very similar, except that we don't insert them into the general Env.

checkFields :: Struct -> Semant Struct
checkFields s@(Struct name fields) = do
  -- The map is only used to detect duplicates; returning the original
  -- list keeps the fields in declaration order.
  _ <- foldM addField M.empty fields
  pure $ Struct name fields
 where
  addField acc field@(Bind t name) = case t of
    TyVoid -> throwError $ IllegalBinding name Void StructField (S s)
    _      -> if M.member name acc
      then throwError (IllegalBinding name Duplicate StructField (S s))
      else pure $ M.insert name field acc

Next, we define some "built in" functions that we'll link in with all our executables. printf, malloc, and free are familiar, and printbig is a real C function that takes an int and prints it to the console in cool ASCII art form.2 Although we're only indicating that printf takes one argument, we'll handle it separately from all other functions, as we have no mechanism to define variadic functions in microc (although that could be a cool extension).

builtIns :: Funcs
builtIns = M.fromList $ map
  toFunc
  [ ("printf"  , [Pointer TyChar], TyVoid)
  , ("printbig", [TyInt]         , TyVoid)
  , ("malloc"  , [TyInt]         , Pointer TyVoid)
  , ("free"    , [Pointer TyVoid], TyVoid)
  ]
 where
  toFunc (name, tys, retty) =
    (name, Function retty name (map (flip Bind "dummy_var") tys) [] [])

Expressions

The code to check expressions is quite extensive. We'll get the self-explanatory cases out of the way first.

checkExpr :: Expr -> Semant SExpr
checkExpr expr = case expr of
  Literal  i -> pure (TyInt, SLiteral i)
  Fliteral f -> pure (TyFloat, SFliteral f)
  BoolLit  b -> pure (TyBool, SBoolLit b)
  CharLit  c -> pure (TyChar, SCharLit c)
  StrLit   s -> pure (Pointer TyChar, SStrLit s)
  Sizeof   t -> pure (TyInt, SSizeof t)
  Null       -> pure (Pointer TyVoid, SNull)
  Noexpr     -> pure (TyVoid, SNoexpr)

For variables, we look for a match among the local variables first, then the formal parameters of the enclosing function, then the global variables. If we don't find anything, we throw an error.

  Id s       -> do
    vars <- gets vars
    let foundVars = map (\kind -> M.lookup (s, kind) vars)
                        [Local, Formal, Global]
    case join $ find isJust foundVars of
      Nothing -> throwError $ UndefinedSymbol s Var expr
      Just ty -> pure (ty, LVal $ SId s)
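The lookup order is what gives us shadowing: a local named x hides a global named x. Here's a small standalone sketch of the same idea. Kind, Scope, lookupVar, and demoScope are hypothetical names made up for this example, types are represented as Strings, and asum is used in place of join $ find isJust (for a list of Maybes, both pick the first Just).

```haskell
import           Data.Foldable                  ( asum )
import qualified Data.Map                      as M

-- Hypothetical stand-ins for the VarKind and Vars types used by the checker.
data Kind = Global | Formal | Local
  deriving (Show, Eq, Ord)

type Scope = M.Map (String, Kind) String

-- Try each kind in priority order and keep the first hit.
-- asum over Maybe returns the first Just, or Nothing if every lookup fails.
lookupVar :: String -> Scope -> Maybe String
lookupVar name scope =
  asum [ M.lookup (name, kind) scope | kind <- [Local, Formal, Global] ]

-- x is declared both globally and locally; y is a formal parameter.
demoScope :: Scope
demoScope = M.fromList
  [ (("x", Global), "int")
  , (("x", Local) , "float")
  , (("y", Formal), "bool")
  ]
```

Here lookupVar "x" demoScope is Just "float" because the local declaration wins, lookupVar "y" demoScope is Just "bool", and lookupVar "z" demoScope is Nothing.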

When checking a binary operation, we first check the two sides and then proceed accordingly. It is useful to define a couple helper functions to assert that both sides of the binary operation have the same type and to check that certain subexpressions are booleans or sensible arithmetic.

  Binop op lhs rhs -> do
    lhs'@(t1, _) <- checkExpr lhs
    rhs'@(t2, _) <- checkExpr rhs

    let
      assertSym = unless (t1 == t2) $ throwError $
        TypeError [t1] t2 (Expr expr)
      checkArith = do
        unless (isNumeric t1) $ throwError $
          TypeError [TyInt, TyFloat] t1 (Expr expr)
        pure (t1, SBinop op lhs' rhs')

      checkBool = do
        unless (t1 == TyBool) $ throwError $
          TypeError [TyBool] t1 (Expr expr)
        pure (t1, SBinop op lhs' rhs')

(isNumeric is defined in a where clause at the end of the function.)

  where
    isNumeric = \case
      TyInt     -> True
      TyFloat   -> True
      TyChar    -> True
      Pointer _ -> True
      _         -> False

Addition is valid between two int's or float's and between a pointer and an int.

    case op of
      Add ->
        let sexpr = SBinop Add lhs' rhs'
        in
          case (t1, t2) of
            (Pointer t, TyInt    ) -> pure (Pointer t, sexpr)
            (TyInt    , Pointer t) -> pure (Pointer t, sexpr)
            (TyInt    , TyInt    ) -> pure (TyInt, sexpr)
            (TyFloat  , TyFloat  ) -> pure (TyFloat, sexpr)
            _ ->
              throwError $ TypeError
                [Pointer TyVoid, TyInt, TyFloat] t1 (Expr expr)

Subtraction is even more liberal than addition, as you can subtract pointers of the same underlying type to get an int.

      Sub ->
        let sexpr = SBinop Sub lhs' rhs'
        in
          case (t1, t2) of
            (Pointer t, TyInt     ) -> pure (Pointer t, sexpr)
            (TyInt    , Pointer t ) -> pure (Pointer t, sexpr)
            (Pointer t, Pointer t') -> if t == t'
              then pure (TyInt, sexpr)
              else throwError $
                     TypeError [Pointer t'] (Pointer t) (Expr expr)
            (TyInt  , TyInt  ) -> pure (TyInt, sexpr)
            (TyFloat, TyFloat) -> pure (TyFloat, sexpr)
            _ -> throwError $ TypeError
              [Pointer TyVoid, TyInt, TyFloat] t1 (Expr expr)

Most of the other mathematical operators are much simpler to check, as they just require that both of their operands have the same type.3

      Mult   -> assertSym >> checkArith
      Div    -> assertSym >> checkArith
      BitAnd -> assertSym >> checkArith
      BitOr  -> assertSym >> checkArith
      And    -> assertSym >> checkBool
      Or     -> assertSym >> checkBool

Our ** operator will work for float ** float, float ** int, and int ** int. The first two will be compiled to the LLVM intrinsics pow and powi respectively, and for the int ** int case we'll write the assembly ourselves.

      Power  -> case (t1, t2) of
        (TyFloat, TyFloat) ->
          pure (TyFloat, SCall "llvm.pow.f64" [lhs', rhs'])
        (TyFloat, TyInt  ) -> 
          pure (TyFloat, SCall "llvm.powi.i32" [lhs', rhs'])
        -- Implement this case directly in llvm
        (TyInt  , TyInt  ) -> pure (TyInt, SBinop Power lhs' rhs')
        _ -> throwError $ TypeError [TyFloat, TyInt] t1 (Expr expr)

The remaining binary operators are all relational. When comparing the null pointer to another pointer, we cast it to the type of the other pointer (but this is the only time that we do this). All other comparisons require both operands to have the same type, and that type must satisfy isNumeric (which also accepts chars and pointers).

      relational -> case (snd lhs', snd rhs') of
        (SNull, _    ) -> checkExpr (Binop relational (Cast t2 lhs) rhs)
        (_    , SNull) -> checkExpr (Binop relational lhs (Cast t1 rhs))
        _              -> do
          assertSym
          unless (isNumeric t1)
            $ throwError (TypeError [TyInt, TyFloat] t1 (Expr expr))
          pure (TyBool, SBinop op lhs' rhs')

Unary operations, by contrast with binary operators, are not complicated.

  Unop op e -> do
    e'@(ty, _) <- checkExpr e
    case op of
      Neg -> do
        unless (isNumeric ty)
          $ throwError (TypeError [TyInt, TyFloat] ty (Expr expr))
        pure (ty, SUnop Neg e')
      Not -> do
        unless (ty == TyBool) $ throwError $
          TypeError [TyBool] ty (Expr expr)
        pure (ty, SUnop Not e')

Taking the address of an expression is valid only on LVal's and dereferencing is only valid for pointer types.

  Addr e -> do
    (t, e') <- checkExpr e
    case e' of
      LVal l -> pure (Pointer t, SAddr l)
      _      -> throwError (AddressError e)

  Deref e -> do
    (ty, e') <- checkExpr e
    case ty of
      Pointer t -> pure (t, LVal $ SDeref (ty, e'))
      _         -> throwError $ TypeError
        [Pointer TyVoid, Pointer TyInt, Pointer TyFloat] ty (Expr expr)

Next, we handle function calls. printf is a special case, as it takes an arbitrary number of arguments. We don't make any effort to check that the format string matches the rest of the arguments. All we check is that the first argument is a string and that every argument is a well-formed expression.

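  Call "printf" es -> do
    es' <- mapM checkExpr es
    -- Pattern match instead of calling head, so that a bare printf()
    -- produces an ArgError rather than a runtime crash
    formatStr <- case es' of
      []          -> throwError (ArgError 1 0 expr)
      (ty, _) : _ -> pure ty
    unless (formatStr == Pointer TyChar)
      $ throwError (TypeError [Pointer TyChar] formatStr (Expr expr))
    pure (TyVoid, SCall "printf" es')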

For user-defined functions and the other built-ins, we look up the name of the function in the environment, throw an error if it doesn't exist, then check that the number and types of all of the parameters match the declaration.

  Call s es -> do
    funcs <- gets funcs
    case M.lookup s funcs of
      Nothing -> throwError $ UndefinedSymbol s Func expr
      Just f  -> do
        es' <- mapM checkExpr es
        -- Check that the correct number of arguments was provided
        let nFormals = length (formals f)
            nActuals = length es
        unless (nFormals == nActuals) $ throwError
          (ArgError nFormals nActuals expr)
        -- Check that types of arguments match
        forM_ (zip (map fst es') (map bindType (formals f)))
          $ \(callSite, defSite) ->
              unless (callSite == defSite) $ throwError $ TypeError
                { expected = [defSite]
                , got      = callSite
                , errorLoc = Expr expr
                }
        pure (typ f, SCall s es')
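The arity check has to come before the type comparison because zip silently stops at the end of the shorter list, so a call with too many or too few arguments would otherwise slip through. The sketch below shows the same two-step check in isolation. checkArgs is a hypothetical helper, with types represented as Strings and errors as plain messages rather than SemantError values.

```haskell
-- Hypothetical helper: compare declared parameter types against the
-- types of the arguments at a call site.
checkArgs :: [String] -> [String] -> Either String ()
checkArgs formals actuals
  | nFormals /= nActuals = Left
    ("expected " ++ show nFormals ++ " arguments, got " ++ show nActuals)
  | otherwise = mapM_ sameType (zip actuals formals)
 where
  nFormals = length formals
  nActuals = length actuals
  -- Types must match exactly; there are no implicit conversions.
  sameType (callSite, defSite)
    | callSite == defSite = Right ()
    | otherwise = Left ("expected " ++ defSite ++ ", got " ++ callSite)
```

For instance, checkArgs ["int"] ["int", "int"] fails with Left "expected 1 arguments, got 2". Without the length guard, zip would have paired up only the first argument and the call would have been accepted.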

For explicit type casts, we allow casts between different pointer types, between int's and pointers, and from int's to float's.4

  Cast t' e -> do
    e'@(t, _) <- checkExpr e
    case (t', t) of
      (Pointer _, Pointer _) -> pure (t', SCast t' e')
      (Pointer _, TyInt    ) -> pure (t', SCast t' e')
      (TyInt    , Pointer _) -> pure (t', SCast t' e')
      (TyFloat  , TyInt    ) -> pure (t', SCast t' e')
      _                      -> throwError $ CastError t' t (Expr expr)

When accessing structs, we first check that the right hand side is a variable5 and that the left hand side is an LVal.

  Access e field -> do
    fieldName <- case field of
      Id f -> pure f
      _    -> throwError (AccessError e field)

    (t, e') <- checkExpr e
    lval    <- case e' of
      LVal l' -> pure l'
      _       -> throwError (AccessError e field)

Then, we check that the type of the thing being accessed is indeed a struct and that it has been declared.

    (Struct _ fields) <- case t of
      TyStruct name' -> do
        ss <- gets structs
        case find (\(Struct n _) -> n == name') ss of
          Nothing -> throwError
            (TypeError [TyStruct "a_struct"] t (Expr expr))
          Just s  -> pure s
      _ -> throwError (TypeError [TyStruct "a_struct"] t (Expr expr))

Finally, we check that the field being accessed exists on the struct declaration and return its position.6

    f <- case findIndex (\(Bind _ f) -> f == fieldName) fields of
      Nothing -> throwError (AccessError e field)
      Just i  -> pure i

    pure (bindType (fields !! f), LVal $ SAccess lval f)
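The name-to-position translation is just a findIndex over the struct's list of fields. As a standalone sketch (Field, fieldIndex, and pointFields are hypothetical names for this example, with types written as Strings):

```haskell
import           Data.List                      ( findIndex )

-- Hypothetical stand-in for Bind: a field's type name and its name.
data Field = Field { fieldType :: String, fieldName :: String }
  deriving (Show, Eq)

-- Zero-based position of a named field in the list, if it exists.
fieldIndex :: String -> [Field] -> Maybe Int
fieldIndex name = findIndex (\f -> fieldName f == name)

-- Fields of an imaginary `struct point { int x; int y; float weight; }`.
pointFields :: [Field]
pointFields = [Field "int" "x", Field "int" "y", Field "float" "weight"]
```

Here fieldIndex "y" pointFields is Just 1, and fieldIndex "z" pointFields is Nothing, which corresponds to the AccessError case above.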

The last node on our expression type is for assignments. We check the left and right hand sides and assert that the left hand side is an LVal.

  Assign lhs rhs -> do
    lhs'@(t1, _) <- checkExpr lhs
    rhs'@(t2, _) <- checkExpr rhs
    lval         <- case snd lhs' of
      LVal e -> pure e
      _      -> throwError $ AssignmentError lhs rhs

Then, we have to handle the special case of assigning NULL to a pointer variable by casting it correctly. For all other variables, we simply assert that the left and right hand sides have equal types.

    case snd rhs' of
      SNull -> checkExpr (Assign lhs (Cast t1 rhs))
      _ -> do
        unless (t1 == t2) $ throwError $ TypeError [t1] t2 (Expr expr)
        pure (t2, SAssign lval rhs')

Statements

Now we come to checkStatement. In addition to the current Statement, we need to know which function we're in so that we can check return types.

The Expr node is easy.

checkStatement :: Function -> Statement -> Semant SStatement
checkStatement func stmt = case stmt of
  Expr e           -> SExpr <$> checkExpr e

For If's, we just check that the predicate is a boolean expression and then recursively descend into the two branches.

  If pred cons alt -> do
    pred'@(ty, _) <- checkExpr pred
    unless (ty == TyBool) $ throwError $ TypeError [TyBool] ty stmt
    SIf pred' <$> checkStatement func cons <*> checkStatement func alt

While loops are very similar, except that we transform

while (cond) { action; }

into

if (cond) { do { action; } while (cond); }

for the aforementioned ease of compilation.

  While cond action -> do
    cond'@(ty, _) <- checkExpr cond
    unless (ty == TyBool) $ throwError $ TypeError [TyBool] ty stmt
    action' <- checkStatement func action
    pure $ SIf cond' (SDoWhile cond' action') (SBlock [])

For loops undergo the transformation from

for (init; cond; inc) {
  action;
}

to

init;
if (cond) {
  do {
    action;
    inc;
  } while (cond);
}

  For init cond inc action -> do
    cond'@(ty, _) <- checkExpr cond
    unless (ty == TyBool) $ throwError $ TypeError [TyBool] ty stmt
    init'   <- checkExpr init
    inc'    <- checkExpr inc
    action' <- checkStatement func action
    pure $ SBlock
      [ SExpr init'
      , SIf cond'
            (SDoWhile cond' (SBlock [action', SExpr inc'])) 
            (SBlock [])
      ]
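Stripped of the type checking, both loop transformations are plain tree rewrites. The sketch below applies them to a toy statement type. Stmt, Action, and desugar are invented for this illustration, and conditions and expressions are just Strings, so nothing is actually checked.

```haskell
-- Hypothetical toy statement type; expressions are just strings here.
data Stmt
  = Action String
  | While String Stmt
  | For String String String Stmt -- init, cond, inc, body
  | If String Stmt Stmt
  | DoWhile String Stmt
  | Block [Stmt]
  deriving (Show, Eq)

-- Rewrite while and for loops in terms of if and do-while.
desugar :: Stmt -> Stmt
desugar stmt = case stmt of
  -- while (c) body  ==>  if (c) { do body while (c); }
  While c body -> If c (DoWhile c (desugar body)) (Block [])
  -- for (i; c; n) body  ==>  i; if (c) { do { body; n; } while (c); }
  For i c n body -> Block
    [ Action i
    , If c (DoWhile c (Block [desugar body, Action n])) (Block [])
    ]
  If c t e    -> If c (desugar t) (desugar e)
  DoWhile c b -> DoWhile c (desugar b)
  Block ss    -> Block (map desugar ss)
  Action _    -> stmt
```

The guarding if is what preserves the semantics: a do-while body always runs at least once, so we only enter it after the condition has held once.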

For returns, check that the expression type matches the declared return type.

  Return expr -> do
    e@(ty, _) <- checkExpr expr
    unless (ty == typ func) $ throwError $ 
      TypeError [typ func] ty stmt
    pure $ SReturn e

Blocks are interesting. It's technically allowed to write code with arbitrarily nested statement blocks:

{
  printf("Hello"); {
    printf(" from this block\n");
  }
}

However, this makes some types of analyses harder, so we flatten such blocks before proceeding with checking. We also want to ensure that dead code, i.e. code after a return, gets flagged as an error. Otherwise, we simply recurse into the block as usual.

  Block sl -> do
    let flattened = flatten sl
    unless (nothingFollowsRet flattened) $ throwError (DeadCode stmt)
    SBlock <$> mapM (checkStatement func) flattened
   where
    flatten []             = []
    flatten (Block s : ss) = flatten (s ++ ss)
    flatten (s       : ss) = s : flatten ss

    nothingFollowsRet []         = True
    nothingFollowsRet [Return _] = True
    nothingFollowsRet (s : ss  ) = case s of
      Return _ -> False
      _        -> nothingFollowsRet ss
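To see why flattening has to happen before the dead code check, here are the two helpers again on a toy statement type. Stmt, Say, and the Int payload of Return are hypothetical simplifications. The helper bodies follow the ones above, with the case expression in nothingFollowsRet written as separate equations.

```haskell
-- Hypothetical toy statement type with just enough structure for the demo.
data Stmt = Say String | Return Int | Block [Stmt]
  deriving (Show, Eq)

-- Splice nested blocks into their parent.
flatten :: [Stmt] -> [Stmt]
flatten []             = []
flatten (Block s : ss) = flatten (s ++ ss)
flatten (s       : ss) = s : flatten ss

-- True when a return, if present, is the last statement in the list.
nothingFollowsRet :: [Stmt] -> Bool
nothingFollowsRet []             = True
nothingFollowsRet [Return _]     = True
nothingFollowsRet (Return _ : _) = False
nothingFollowsRet (_ : ss)       = nothingFollowsRet ss

-- A return hidden inside a nested block, followed by unreachable code.
hidden :: [Stmt]
hidden = [Block [Return 0], Say "unreachable"]
```

On the raw list, nothingFollowsRet hidden is True, because the return is tucked away inside a Block and the check never sees it. After flattening, flatten hidden is [Return 0, Say "unreachable"] and nothingFollowsRet (flatten hidden) is False, so the dead code is caught.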

Functions

For functions, we first check for naming conflicts, then add the formal parameters and local variables to the environment locally for the duration of checking just this function.7

checkFunction :: Function -> Semant SFunction
checkFunction func = do
  -- add the fname to the table and check for conflicts
  funcs <- gets funcs
  unless (M.notMember (name func) funcs) $ throwError $
    Redeclaration (name func)
  -- add this func to symbol table
  modify $ \env -> env { funcs = M.insert (name func) func funcs }

  (formals', locals', body') <- locally $ liftM3
    (,,)
    (checkBinds Formal (F func) (formals func))
    (checkBinds Local (F func) (locals func))
    (checkStatement func (Block $ body func))

When checking the body of the function, we will be quite a bit stricter than most C compilers and assert that unless the function is marked void, all paths in control flow must end in a return.

  case body' of
    SBlock body'' -> do
      unless (typ func == TyVoid || validate (genCFG body''))
        $ throwError (TypeError [typ func] TyVoid (Block $ body func))

      pure $ SFunction { styp     = typ func
                       , sname    = name func
                       , sformals = formals'
                       , slocals  = locals'
                       , sbody    = SBlock body''
                       }
    _ -> error "Internal error - block didn't become a block?"

To accomplish this, we'll create a new module, Microc.Semant.Analysis, and define a very simple control flow graph structure: a CFG is either empty, meaning that the branch is over and there are no more statements, a sequence of statements, or a branch point.

{-# LANGUAGE LambdaCase #-}
module Microc.Semant.Analysis where

import Microc.Sast

-- | The Bool in Seq is True if the statement is a return
data CFG = Empty | Seq Bool CFG | Branch CFG CFG

To create the CFG from a list of statements, we simply recurse.

-- | By this point, the dead code invariant will have been checked
genCFG :: [SStatement] -> CFG
genCFG [] = Empty
genCFG (s:ss) = case s of
    SReturn _ -> Seq True (genCFG ss)
    SIf _ cons alt -> Branch (genCFG (cons : ss)) (genCFG (alt:ss))
    SDoWhile _ stmt -> Seq False (genCFG (stmt:ss))
    SBlock stmts -> genCFG (stmts <> ss)
    _ -> Seq False (genCFG ss)

A valid CFG is one in which all branches end in returns.

-- | Traverses cfg and returns true if all leaves are true
validate :: CFG -> Bool
validate = \case
  Empty -> False
  Seq b Empty -> b
  Seq _ rest -> validate rest
  Branch left right -> validate left && validate right

Note that this can be unnecessarily stringent, as there are some cases such as

int main() {
  if (true) {
    return 0;  
  }
}

that our compiler will reject, even though it's clear that control flow will never pass the if statement. However, I prefer this to being overly permissive.
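The following self-contained sketch runs the same analysis on a toy statement type so that the accepted and rejected shapes can be compared side by side. Stmt, Ret, Other, and IfElse are hypothetical simplifications (conditions are dropped entirely, since only the shape of the control flow matters), while CFG, genCFG, and validate mirror the definitions above.

```haskell
-- Hypothetical toy statements; conditions are omitted on purpose.
data Stmt = Ret | Other | IfElse Stmt Stmt | Block [Stmt]

-- Same CFG shape as above: the Bool marks whether a statement is a return.
data CFG = Empty | Seq Bool CFG | Branch CFG CFG

genCFG :: [Stmt] -> CFG
genCFG []       = Empty
genCFG (s : ss) = case s of
  Ret             -> Seq True (genCFG ss)
  IfElse cons alt -> Branch (genCFG (cons : ss)) (genCFG (alt : ss))
  Block stmts     -> genCFG (stmts ++ ss)
  Other           -> Seq False (genCFG ss)

-- True only if every path through the graph ends in a return.
validate :: CFG -> Bool
validate Empty               = False
validate (Seq b Empty)       = b
validate (Seq _ rest)        = validate rest
validate (Branch left right) = validate left && validate right

-- if (..) { return; } else { return; }
bothReturn :: [Stmt]
bothReturn = [IfElse (Block [Ret]) (Block [Ret])]

-- if (..) { return; }            (empty else branch, nothing afterwards)
onlyThenReturns :: [Stmt]
onlyThenReturns = [IfElse (Block [Ret]) (Block [])]

-- if (..) { return; } return;    (the fall-through path also returns)
returnAfterIf :: [Stmt]
returnAfterIf = [IfElse (Block [Ret]) (Block []), Ret]
```

validate (genCFG bothReturn) and validate (genCFG returnAfterIf) are both True, while validate (genCFG onlyThenReturns) is False. The last one has the same shape as the if (true) example: the analysis never looks at the condition, so the empty else branch counts as a path that falls off the end of the function.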

Wrapping up Semant

Finally, we can write checkProgram, which will unwrap the Semant monad, insert the built in functions into the environment, then run checkers for struct declarations, global variables, and function definitions. Note that we reject all programs which don't define a "main" function.

checkProgram :: Program -> Either SemantError SProgram
checkProgram program =
  evalState (runExceptT (checkProgram' program)) baseEnv
 where
  baseEnv = Env { structs = [], vars = M.empty, funcs = builtIns }

  checkProgram' :: Program -> Semant SProgram
  checkProgram' (Program structs binds funcs) = do
    structs' <- mapM checkFields structs
    modify $ \e -> e { structs = structs' }
    globals <- checkBinds Global Toplevel binds
    funcs'  <- mapM checkFunction funcs
    case find (\f -> sname f == "main") funcs' of
      Nothing -> throwError NoMain
      Just _  -> pure (structs', globals, funcs')

(Full source for Semant.hs here.)

More wiring

In part 1, we stubbed out Main.hs. Now that semantic checking works (hopefully), we'll go back to runOpts and add a case to pretty print the Sast. As this is mainly just for debugging, we won't bother hand writing a pretty-printer for the Sast. Instead, we'll use pretty-simple to get nice output for the auto-derived Show instance.

runOpts :: Options -> IO ()
runOpts (Options action infile ptype) = do
  program <- T.readFile infile
  let parseTree = case ptype of
        Combinator -> runParser programP infile program
        Generator  -> Right $ parse . alexScanTokens $ T.unpack program
  case parseTree of
    Left  err -> putStrLn $ errorBundlePretty err
    Right ast -> case action of
      Ast -> putDoc $ pretty ast <> "\n"
      _   -> case checkProgram ast of
        Left err -> putDoc $ pretty err <> "\n"
        Right sast -> case action of
          Sast -> pPrint sast
          _ -> error "Haven't written codegen yet"

Testing

Now that we're done with the parts of the compiler that deal with rejecting invalid programs, we should test that functionality. We'll construct a new TestTree that looks in the tests/fail directory and asserts that running the compiler on each ".mc" file produces the same error message as specified by the corresponding ".golden" file.

failing :: IO TestTree
failing = do
  mcFiles <- findByExtension [".mc"] "tests/fail"
  return $ testGroup
    "fail"
    [ goldenVsString (takeBaseName mcFile) outfile (cs <$> runFile mcFile)
    | mcFile <- mcFiles
    , let outfile = replaceExtension mcFile ".golden"
    ]

-- | Given a microc file, attempt to compile and execute it and read the
-- results to be compared with what should be the correct output
runFile :: FilePath -> IO Text
runFile infile = do
  program <- T.readFile infile
  let parseTree = runParser programP (cs infile) program
  case parseTree of
    Left  e   -> return . cs $ errorBundlePretty e
    Right ast -> case checkProgram ast of
      Left err -> return . renderStrict $
        layoutPretty defaultLayoutOptions (pretty err)
      Right sast -> error "Codegen not yet implemented"

We'll also refactor main to account for these new tests.

main :: IO ()
main = defaultMain =<< goldenTests

-- | All of the test cases
-- General structure taken from
-- https://ro-che.info/articles/2017-12-04-golden-tests
goldenTests :: IO TestTree
goldenTests = testGroup "all" <$> sequence [failing, parsing]

Now, if we want to run only the parsing tests, or only the semantic tests, we can, inside nix-shell, run cabal run testall -- -p "parsing" or cabal run testall -- -p "fail". Instead of writing out the error messages we expect in each golden file ahead of time, we'll use the workflow proposed by the Tasty author, Roman Cheplyaka: the first time we run the tests, we run cabal run testall -- --accept, then commit the generated golden files after manually inspecting them. This is where golden tests really shine, as writing out all the error messages beforehand in a file while keeping the correct formatting is too horrible to contemplate.

Finally, the semantic checker is done! All that's left is code generation, which we'll tackle in part 3. Thanks to everyone for reading so far, and especially to Théophile Choutri, who's been sponsoring this series and has set up a microfund for it.


1

As of April 2020, GHC does this, although there's a proposal to change to a different approach.

2

printbig originally appeared in the ocaml microc compiler that this is based on.

3

There's an or-patterns extension proposal that's been sitting around for years that, if accepted, would reduce a lot of duplication, but sadly it looks like it's going nowhere fast.

4

There's nothing stopping us from allowing more casts, in principle, but they're usually quite dangerous and invoke undefined behavior.

5

This could actually be done in the parser but the way that megaparsec handles expression parsing makes it awkward/impossible to assign asymmetric binary operators a precedence, so we have to pretend that . and -> take Expr's on both sides as opposed to an Expr and Text and enforce the invariant at check time.

6

Notice that we've done a lot of var <- case thing of Nothing -> error; Just ok -> pure ok. We could have continued along in the Just branch instead but I find that this leads to an undesirable amount of indentation that quickly becomes unwieldy.

7

locally is defined in Microc.Utils. It prevents us from polluting the global environment with local variable names.

module Microc.Utils where
import Control.Monad.State
locally :: MonadState s m => m a -> m a
locally computation = do
  oldState <- get
  result   <- computation
  put oldState
  return result
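Here's a small demonstration of the rollback, using a plain State monad instead of Semant. The locally definition is repeated so that the snippet stands on its own. demo, demoResult, and the variable names in the list are made up for this example.

```haskell
import           Control.Monad.State            ( MonadState
                                                , State
                                                , get
                                                , modify
                                                , put
                                                , runState
                                                )

locally :: MonadState s m => m a -> m a
locally computation = do
  oldState <- get
  result   <- computation
  put oldState
  return result

-- Hypothetical demo: the state is a list of declared variable names.
demo :: State [String] [String]
demo = do
  modify ("global_x" :)
  -- Inside locally, the extra name is visible...
  inner <- locally $ do
    modify ("local_y" :)
    get
  -- ...but the state is restored once the local computation finishes.
  pure inner

-- Pair of (what the inner computation saw, the final state).
demoResult :: ([String], [String])
demoResult = runState demo []
```

demoResult evaluates to (["local_y", "global_x"], ["global_x"]): the inner computation saw both names, but only the global one survives in the final state.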