# HG changeset patch # User Michael Pavone # Date 1548737695 28800 # Node ID 89932fd29abd3d7494a895b2339d8126e89bbcfc # Parent 49a52c737bf0b87d6333d00cc6bc09d4965eabbe First stab at carry and half-carry calculation in CPU DSL diff -r 49a52c737bf0 -r 89932fd29abd cpu_dsl.py --- a/cpu_dsl.py Mon Jan 28 19:24:04 2019 -0800 +++ b/cpu_dsl.py Mon Jan 28 20:54:55 2019 -0800 @@ -20,6 +20,14 @@ self.addOp(NormalOp(parts)) return self + def processOps(self, prog, fieldVals, output, otype, oplist): + for i in range(0, len(oplist)): + if i + 1 < len(oplist) and oplist[i+1].op == 'update_flags': + flagUpdates, _ = prog.flags.parseFlagUpdate(oplist[i+1].params[0]) + else: + flagUpdates = None + oplist[i].generate(prog, self, fieldVals, output, otype, flagUpdates) + def resolveLocal(self, name): return None @@ -121,8 +129,7 @@ output.append('\n\tuint{sz}_t {name};'.format(sz=self.locals[var], name=var)) self.newLocals = [] fieldVals,_ = self.getFieldVals(value) - for op in self.implementation: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.implementation) begin = '\nvoid ' + self.generateName(value) + '(' + prog.context_type + ' *context)\n{' if prog.needFlagCoalesce: begin += prog.flags.coalesceFlags(prog, otype) @@ -189,8 +196,7 @@ for name in self.locals: size = self.locals[name] output.append('\n\tuint{size}_t {sub}_{local};'.format(size=size, sub=self.name, local=name)) - for op in self.implementation: - op.generate(prog, self, argValues, output, otype) + self.processOps(prog, argValues, output, otype, self.implementation) prog.popScope() def __str__(self): @@ -209,15 +215,36 @@ self.impls = {} self.outOp = () def cBinaryOperator(self, op): - def _impl(prog, params): + def _impl(prog, params, rawParams, flagUpdates): if op == '-': a = params[1] b = params[0] else: a = params[0] b = params[1] - return '\n\t{dst} = {a} {op} {b};'.format( - dst = params[2], a = a, b = b, op = op + needsCarry = needsOflow = needsHalf = 
False + if flagUpdates: + for flag in flagUpdates: + calc = prog.flags.flagCalc[flag] + if calc == 'carry': + needsCarry = True + elif calc == 'half-carry': + needsHalf = True + elif calc == 'overflow': + needsOflow = True + decl = '' + if needsCarry or needsOflow or needsHalf: + size = prog.paramSize(rawParams[2]) + if needsCarry: + size *= 2 + decl,name = prog.getTemp(size) + dst = prog.carryFlowDst = name + prog.lastA = a + prog.lastB = b + else: + dst = params[2] + return decl + '\n\t{dst} = {a} {op} {b};'.format( + dst = dst, a = a, b = b, op = op ) self.impls['c'] = _impl self.outOp = (2,) @@ -244,11 +271,13 @@ return not self.evalFun is None def numArgs(self): return self.evalFun.__code__.co_argcount - def generate(self, otype, prog, params, rawParams): + def generate(self, otype, prog, params, rawParams, flagUpdates): if self.impls[otype].__code__.co_argcount == 2: return self.impls[otype](prog, params) + elif self.impls[otype].__code__.co_argcount == 3: + return self.impls[otype](prog, params, rawParams) else: - return self.impls[otype](prog, params, rawParams) + return self.impls[otype](prog, params, rawParams, flagUpdates) def _xchgCImpl(prog, params, rawParams): @@ -264,36 +293,26 @@ return '\n\timpl_{tbl}[{op}](context);'.format(tbl = table, op = params[0]) def _updateFlagsCImpl(prog, params, rawParams): - i = 0 - last = '' - autoUpdate = set() - explicit = {} - for c in params[0]: - if c.isdigit(): - if last.isalpha(): - num = int(c) - if num > 1: - raise Exception(c + ' is not a valid digit for update_flags') - explicit[last] = num - last = c - else: - raise Exception('Digit must follow flag letter in update_flags') - else: - if last.isalpha(): - autoUpdate.add(last) - last = c - if last.isalpha(): - autoUpdate.add(last) + autoUpdate, explicit = prog.flags.parseFlagUpdate(params[0]) output = [] #TODO: handle autoUpdate flags for flag in autoUpdate: calc = prog.flags.flagCalc[flag] calc,_,resultBit = calc.partition('-') - lastDst = 
prog.resolveParam(prog.lastDst, None, {}) + if prog.carryFlowDst: + lastDst = prog.carryFlowDst + else: + lastDst = prog.resolveParam(prog.lastDst, None, {}) storage = prog.flags.getStorage(flag) - if calc == 'bit' or calc == 'sign': + if calc == 'bit' or calc == 'sign' or calc == 'carry' or calc == 'half': + myRes = lastDst if calc == 'sign': resultBit = prog.paramSize(prog.lastDst) - 1 + elif calc == 'carry': + resultBit = prog.paramSize(prog.lastDst) + elif calc == 'half': + resultBit = 4 + myRes = '({a} ^ {b} ^ {res})'.format(a = prog.lastA, b = prog.lastB, res = lastDst) else: resultBit = int(resultBit) if type(storage) is tuple: @@ -302,7 +321,7 @@ if storageBit == resultBit: #TODO: optimize this case output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} & {mask}U);'.format( - reg = reg, mask = 1 << resultBit, res = lastDst + reg = reg, mask = 1 << resultBit, res = myRes )) else: if resultBit > storageBit: @@ -312,11 +331,11 @@ op = '<<' shift = storageBit - resultBit output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} {op} {shift}U & {mask}U);'.format( - reg = reg, mask = 1 << storageBit, res = lastDst, op = op, shift = shift + reg = reg, mask = 1 << storageBit, res = myRes, op = op, shift = shift )) else: reg = prog.resolveParam(storage, None, {}) - output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=lastDst, mask = 1 << resultBit)) + output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=myRes, mask = 1 << resultBit)) elif calc == 'zero': if type(storage) is tuple: reg,storageBit = storage @@ -328,15 +347,16 @@ reg = prog.resolveParam(storage, None, {}) output.append('\n\t{reg} = {res} == 0;'.format( reg = reg, res = lastDst - )) - elif calc == 'half-carry': - pass - elif calc == 'carry': - pass + )) elif calc == 'overflow': pass elif calc == 'parity': pass + else: + raise Exception('Unknown flag calc type: ' + calc) + if prog.carryFlowDst: + output.append('\n\t{dst} = {tmpdst};'.format(dst = prog.resolveParam(prog.lastDst, None, 
{}), tmpdst = prog.carryFlowDst)) + prog.carryFlowDst = None #TODO: combine explicit flags targeting the same storage location for flag in explicit: location = prog.flags.getStorage(flag) @@ -458,9 +478,9 @@ self.op = parts[0] self.params = parts[1:] - def generate(self, prog, parent, fieldVals, output, otype): + def generate(self, prog, parent, fieldVals, output, otype, flagUpdates): procParams = [] - allParamsConst = True + allParamsConst = flagUpdates is None opDef = _opMap.get(self.op) for param in self.params: allowConst = (self.op in prog.subroutines or len(procParams) != len(self.params) - 1) and param in parent.regValues @@ -502,9 +522,9 @@ if prog.isReg(dst): shortProc = (procParams[0], procParams[-1]) shortParams = (self.params[0], self.params[-1]) - output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams)) + output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams, None)) else: - output.append(opDef.generate(otype, prog, procParams, self.params)) + output.append(opDef.generate(otype, prog, procParams, self.params, flagUpdates)) elif self.op in prog.subroutines: prog.subroutines[self.op].inline(prog, procParams, output, otype, parent) else: @@ -558,7 +578,7 @@ return self.current_locals[name] return self.parent.localSize(name) - def generate(self, prog, parent, fieldVals, output, otype): + def generate(self, prog, parent, fieldVals, output, otype, flagUpdates): prog.pushScope(self) param = prog.resolveParam(self.param, parent, fieldVals) if type(param) is int: @@ -568,16 +588,14 @@ output.append('\n\t{') for local in self.case_locals[param]: output.append('\n\tuint{0}_t {1};'.format(self.case_locals[param][local], local)) - for op in self.cases[param]: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.cases[param]) output.append('\n\t}') elif self.default: self.current_locals = self.default_locals output.append('\n\t{') for local in self.default_locals: 
output.append('\n\tuint{0}_t {1};'.format(self.default[local], local)) - for op in self.default: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.default) output.append('\n\t}') else: output.append('\n\tswitch(' + param + ')') @@ -588,8 +606,7 @@ output.append('\n\tcase {0}U: '.format(case) + '{') for local in self.case_locals[case]: output.append('\n\tuint{0}_t {1};'.format(self.case_locals[case][local], local)) - for op in self.cases[case]: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.cases[case]) output.append('\n\tbreak;') output.append('\n\t}') if self.default: @@ -598,8 +615,7 @@ output.append('\n\tdefault: {') for local in self.default_locals: output.append('\n\tuint{0}_t {1};'.format(self.default_locals[local], local)) - for op in self.default: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.default) output.append('\n\t}') prog.popScope() @@ -666,15 +682,13 @@ self.curLocals = self.locals for local in self.locals: output.append('\n\tuint{sz}_t {nm};'.format(sz=self.locals[local], nm=local)) - for op in self.body: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.body) def _genFalseBody(self, prog, fieldVals, output, otype): self.curLocals = self.elseLocals for local in self.elseLocals: output.append('\n\tuint{sz}_t {nm};'.format(sz=self.elseLocals[local], nm=local)) - for op in self.elseBody: - op.generate(prog, self, fieldVals, output, otype) + self.processOps(prog, fieldVals, output, otype, self.elseBody) def _genConstParam(self, param, prog, fieldVals, output, otype): if param: @@ -682,7 +696,7 @@ else: self._genFalseBody(prog, fieldVals, output, otype) - def generate(self, prog, parent, fieldVals, output, otype): + def generate(self, prog, parent, fieldVals, output, otype, flagUpdates): self.regValues = parent.regValues
try: self._genConstParam(prog.checkBool(self.cond), prog, fieldVals, output, otype) @@ -829,6 +843,28 @@ else: return loc + def parseFlagUpdate(self, flagString): + last = '' + autoUpdate = set() + explicit = {} + for c in flagString: + if c.isdigit(): + if last.isalpha(): + num = int(c) + if num > 1: + raise Exception(c + ' is not a valid digit for update_flags') + explicit[last] = num + last = c + else: + raise Exception('Digit must follow flag letter in update_flags') + else: + if last.isalpha(): + autoUpdate.add(last) + last = c + if last.isalpha(): + autoUpdate.add(last) + return (autoUpdate, explicit) + def disperseFlags(self, prog, otype): bitToFlag = [None] * (self.maxBit+1) src = prog.resolveReg(self.flagReg, None, {}) @@ -949,6 +985,9 @@ self.scopes = [] self.currentScope = None self.lastOp = None + self.carryFlowDst = None + self.lastA = None + self.lastB = None def __str__(self): pieces = []