comparison cpu_dsl.py @ 1704:89932fd29abd

First stab at carry and half-carry calculation in CPU DSL
author Michael Pavone <pavone@retrodev.com>
date Mon, 28 Jan 2019 20:54:55 -0800
parents 49a52c737bf0
children 9ab64ef5cba0
comparison
equal deleted inserted replaced
1703:49a52c737bf0 1704:89932fd29abd
17 elif parts[0] == 'end': 17 elif parts[0] == 'end':
18 raise Exception('end is only allowed inside a switch or if block') 18 raise Exception('end is only allowed inside a switch or if block')
19 else: 19 else:
20 self.addOp(NormalOp(parts)) 20 self.addOp(NormalOp(parts))
21 return self 21 return self
22
23 def processOps(self, prog, fieldVals, output, otype, oplist):
24 for i in range(0, len(oplist)):
25 if i + 1 < len(oplist) and oplist[i+1].op == 'update_flags':
26 flagUpdates, _ = prog.flags.parseFlagUpdate(oplist[i+1].params[0])
27 else:
28 flagUpdates = None
29 oplist[i].generate(prog, self, fieldVals, output, otype, flagUpdates)
22 30
23 def resolveLocal(self, name): 31 def resolveLocal(self, name):
24 return None 32 return None
25 33
26 class ChildBlock(Block): 34 class ChildBlock(Block):
119 self.regValues = {} 127 self.regValues = {}
120 for var in self.locals: 128 for var in self.locals:
121 output.append('\n\tuint{sz}_t {name};'.format(sz=self.locals[var], name=var)) 129 output.append('\n\tuint{sz}_t {name};'.format(sz=self.locals[var], name=var))
122 self.newLocals = [] 130 self.newLocals = []
123 fieldVals,_ = self.getFieldVals(value) 131 fieldVals,_ = self.getFieldVals(value)
124 for op in self.implementation: 132 self.processOps(prog, fieldVals, output, otype, self.implementation)
125 op.generate(prog, self, fieldVals, output, otype)
126 begin = '\nvoid ' + self.generateName(value) + '(' + prog.context_type + ' *context)\n{' 133 begin = '\nvoid ' + self.generateName(value) + '(' + prog.context_type + ' *context)\n{'
127 if prog.needFlagCoalesce: 134 if prog.needFlagCoalesce:
128 begin += prog.flags.coalesceFlags(prog, otype) 135 begin += prog.flags.coalesceFlags(prog, otype)
129 if prog.needFlagDisperse: 136 if prog.needFlagDisperse:
130 output.append(prog.flags.disperseFlags(prog, otype)) 137 output.append(prog.flags.disperseFlags(prog, otype))
187 argValues[name] = params[i] 194 argValues[name] = params[i]
188 i += 1 195 i += 1
189 for name in self.locals: 196 for name in self.locals:
190 size = self.locals[name] 197 size = self.locals[name]
191 output.append('\n\tuint{size}_t {sub}_{local};'.format(size=size, sub=self.name, local=name)) 198 output.append('\n\tuint{size}_t {sub}_{local};'.format(size=size, sub=self.name, local=name))
192 for op in self.implementation: 199 self.processOps(prog, argValues, output, otype, self.implementation)
193 op.generate(prog, self, argValues, output, otype)
194 prog.popScope() 200 prog.popScope()
195 201
196 def __str__(self): 202 def __str__(self):
197 pieces = [self.name] 203 pieces = [self.name]
198 for name,size in self.args: 204 for name,size in self.args:
207 def __init__(self, evalFun = None): 213 def __init__(self, evalFun = None):
208 self.evalFun = evalFun 214 self.evalFun = evalFun
209 self.impls = {} 215 self.impls = {}
210 self.outOp = () 216 self.outOp = ()
211 def cBinaryOperator(self, op): 217 def cBinaryOperator(self, op):
212 def _impl(prog, params): 218 def _impl(prog, params, rawParams, flagUpdates):
213 if op == '-': 219 if op == '-':
214 a = params[1] 220 a = params[1]
215 b = params[0] 221 b = params[0]
216 else: 222 else:
217 a = params[0] 223 a = params[0]
218 b = params[1] 224 b = params[1]
219 return '\n\t{dst} = {a} {op} {b};'.format( 225 needsCarry = needsOflow = needsHalf = False
220 dst = params[2], a = a, b = b, op = op 226 if flagUpdates:
227 for flag in flagUpdates:
228 calc = prog.flags.flagCalc[flag]
229 if calc == 'carry':
230 needsCarry = True
231 elif calc == 'half-carry':
232 needsHalf = True
233 elif calc == 'overflow':
234 needsOflow = True
235 decl = ''
236 if needsCarry or needsOflow or needsHalf:
237 size = prog.paramSize(rawParams[2])
238 if needsCarry:
239 size *= 2
240 decl,name = prog.getTemp(size)
241 dst = prog.carryFlowDst = name
242 prog.lastA = a
243 prog.lastB = b
244 else:
245 dst = params[2]
246 return decl + '\n\t{dst} = {a} {op} {b};'.format(
247 dst = dst, a = a, b = b, op = op
221 ) 248 )
222 self.impls['c'] = _impl 249 self.impls['c'] = _impl
223 self.outOp = (2,) 250 self.outOp = (2,)
224 return self 251 return self
225 def cUnaryOperator(self, op): 252 def cUnaryOperator(self, op):
242 return self.evalFun(*params) 269 return self.evalFun(*params)
243 def canEval(self): 270 def canEval(self):
244 return not self.evalFun is None 271 return not self.evalFun is None
245 def numArgs(self): 272 def numArgs(self):
246 return self.evalFun.__code__.co_argcount 273 return self.evalFun.__code__.co_argcount
247 def generate(self, otype, prog, params, rawParams): 274 def generate(self, otype, prog, params, rawParams, flagUpdates):
248 if self.impls[otype].__code__.co_argcount == 2: 275 if self.impls[otype].__code__.co_argcount == 2:
249 return self.impls[otype](prog, params) 276 return self.impls[otype](prog, params)
250 else: 277 elif self.impls[otype].__code__.co_argcount == 3:
251 return self.impls[otype](prog, params, rawParams) 278 return self.impls[otype](prog, params, rawParams)
279 else:
280 return self.impls[otype](prog, params, rawParams, flagUpdates)
252 281
253 282
254 def _xchgCImpl(prog, params, rawParams): 283 def _xchgCImpl(prog, params, rawParams):
255 size = prog.paramSize(rawParams[0]) 284 size = prog.paramSize(rawParams[0])
256 decl,name = prog.getTemp(size) 285 decl,name = prog.getTemp(size)
262 else: 291 else:
263 table = params[1] 292 table = params[1]
264 return '\n\timpl_{tbl}[{op}](context);'.format(tbl = table, op = params[0]) 293 return '\n\timpl_{tbl}[{op}](context);'.format(tbl = table, op = params[0])
265 294
266 def _updateFlagsCImpl(prog, params, rawParams): 295 def _updateFlagsCImpl(prog, params, rawParams):
267 i = 0 296 autoUpdate, explicit = prog.flags.parseFlagUpdate(params[0])
268 last = ''
269 autoUpdate = set()
270 explicit = {}
271 for c in params[0]:
272 if c.isdigit():
273 if last.isalpha():
274 num = int(c)
275 if num > 1:
276 raise Exception(c + ' is not a valid digit for update_flags')
277 explicit[last] = num
278 last = c
279 else:
280 raise Exception('Digit must follow flag letter in update_flags')
281 else:
282 if last.isalpha():
283 autoUpdate.add(last)
284 last = c
285 if last.isalpha():
286 autoUpdate.add(last)
287 output = [] 297 output = []
288 #TODO: handle autoUpdate flags 298 #TODO: handle autoUpdate flags
289 for flag in autoUpdate: 299 for flag in autoUpdate:
290 calc = prog.flags.flagCalc[flag] 300 calc = prog.flags.flagCalc[flag]
291 calc,_,resultBit = calc.partition('-') 301 calc,_,resultBit = calc.partition('-')
292 lastDst = prog.resolveParam(prog.lastDst, None, {}) 302 if prog.carryFlowDst:
303 lastDst = prog.carryFlowDst
304 else:
305 lastDst = prog.resolveParam(prog.lastDst, None, {})
293 storage = prog.flags.getStorage(flag) 306 storage = prog.flags.getStorage(flag)
294 if calc == 'bit' or calc == 'sign': 307 if calc == 'bit' or calc == 'sign' or calc == 'carry' or calc == 'half':
308 myRes = lastDst
295 if calc == 'sign': 309 if calc == 'sign':
296 resultBit = prog.paramSize(prog.lastDst) - 1 310 resultBit = prog.paramSize(prog.lastDst) - 1
311 elif calc == 'carry':
312 resultBit = prog.paramSize(prog.lastDst)
313 elif calc == 'half':
314 resultBit = 4
315 myRes = '({a} ^ {b} ^ {res})'.format(a = prog.lastA, b = prog.lastB, res = lastDst)
297 else: 316 else:
298 resultBit = int(resultBit) 317 resultBit = int(resultBit)
299 if type(storage) is tuple: 318 if type(storage) is tuple:
300 reg,storageBit = storage 319 reg,storageBit = storage
301 reg = prog.resolveParam(reg, None, {}) 320 reg = prog.resolveParam(reg, None, {})
302 if storageBit == resultBit: 321 if storageBit == resultBit:
303 #TODO: optimize this case 322 #TODO: optimize this case
304 output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} & {mask}U);'.format( 323 output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} & {mask}U);'.format(
305 reg = reg, mask = 1 << resultBit, res = lastDst 324 reg = reg, mask = 1 << resultBit, res = myRes
306 )) 325 ))
307 else: 326 else:
308 if resultBit > storageBit: 327 if resultBit > storageBit:
309 op = '>>' 328 op = '>>'
310 shift = resultBit - storageBit 329 shift = resultBit - storageBit
311 else: 330 else:
312 op = '<<' 331 op = '<<'
313 shift = storageBit - resultBit 332 shift = storageBit - resultBit
314 output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} {op} {shift}U & {mask}U);'.format( 333 output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} {op} {shift}U & {mask}U);'.format(
315 reg = reg, mask = 1 << storageBit, res = lastDst, op = op, shift = shift 334 reg = reg, mask = 1 << storageBit, res = myRes, op = op, shift = shift
316 )) 335 ))
317 else: 336 else:
318 reg = prog.resolveParam(storage, None, {}) 337 reg = prog.resolveParam(storage, None, {})
319 output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=lastDst, mask = 1 << resultBit)) 338 output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=myRes, mask = 1 << resultBit))
320 elif calc == 'zero': 339 elif calc == 'zero':
321 if type(storage) is tuple: 340 if type(storage) is tuple:
322 reg,storageBit = storage 341 reg,storageBit = storage
323 reg = prog.resolveParam(reg, None, {}) 342 reg = prog.resolveParam(reg, None, {})
324 output.append('\n\t{reg} = {res} ? ({reg} & {mask}U) : ({reg} | {bit}U);'.format( 343 output.append('\n\t{reg} = {res} ? ({reg} & {mask}U) : ({reg} | {bit}U);'.format(
326 )) 345 ))
327 else: 346 else:
328 reg = prog.resolveParam(storage, None, {}) 347 reg = prog.resolveParam(storage, None, {})
329 output.append('\n\t{reg} = {res} == 0;'.format( 348 output.append('\n\t{reg} = {res} == 0;'.format(
330 reg = reg, res = lastDst 349 reg = reg, res = lastDst
331 )) 350 ))
332 elif calc == 'half-carry':
333 pass
334 elif calc == 'carry':
335 pass
336 elif calc == 'overflow': 351 elif calc == 'overflow':
337 pass 352 pass
338 elif calc == 'parity': 353 elif calc == 'parity':
339 pass 354 pass
355 else:
356 raise Exception('Unknown flag calc type: ' + calc)
357 if prog.carryFlowDst:
358 output.append('\n\t{dst} = {tmpdst};'.format(dst = prog.resolveParam(prog.lastDst, None, {}), tmpdst = prog.carryFlowDst))
359 prog.carryFlowDst = None
340 #TODO: combine explicit flags targeting the same storage location 360 #TODO: combine explicit flags targeting the same storage location
341 for flag in explicit: 361 for flag in explicit:
342 location = prog.flags.getStorage(flag) 362 location = prog.flags.getStorage(flag)
343 if type(location) is tuple: 363 if type(location) is tuple:
344 reg,bit = location 364 reg,bit = location
456 class NormalOp: 476 class NormalOp:
457 def __init__(self, parts): 477 def __init__(self, parts):
458 self.op = parts[0] 478 self.op = parts[0]
459 self.params = parts[1:] 479 self.params = parts[1:]
460 480
461 def generate(self, prog, parent, fieldVals, output, otype): 481 def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
462 procParams = [] 482 procParams = []
463 allParamsConst = True 483 allParamsConst = flagUpdates is None
464 opDef = _opMap.get(self.op) 484 opDef = _opMap.get(self.op)
465 for param in self.params: 485 for param in self.params:
466 allowConst = (self.op in prog.subroutines or len(procParams) != len(self.params) - 1) and param in parent.regValues 486 allowConst = (self.op in prog.subroutines or len(procParams) != len(self.params) - 1) and param in parent.regValues
467 isDst = (not opDef is None) and len(procParams) in opDef.outOp 487 isDst = (not opDef is None) and len(procParams) in opDef.outOp
468 param = prog.resolveParam(param, parent, fieldVals, allowConst, isDst) 488 param = prog.resolveParam(param, parent, fieldVals, allowConst, isDst)
500 dst = maybeLocal 520 dst = maybeLocal
501 parent.regValues[dst] = result 521 parent.regValues[dst] = result
502 if prog.isReg(dst): 522 if prog.isReg(dst):
503 shortProc = (procParams[0], procParams[-1]) 523 shortProc = (procParams[0], procParams[-1])
504 shortParams = (self.params[0], self.params[-1]) 524 shortParams = (self.params[0], self.params[-1])
505 output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams)) 525 output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams, None))
506 else: 526 else:
507 output.append(opDef.generate(otype, prog, procParams, self.params)) 527 output.append(opDef.generate(otype, prog, procParams, self.params, flagUpdates))
508 elif self.op in prog.subroutines: 528 elif self.op in prog.subroutines:
509 prog.subroutines[self.op].inline(prog, procParams, output, otype, parent) 529 prog.subroutines[self.op].inline(prog, procParams, output, otype, parent)
510 else: 530 else:
511 output.append('\n\t' + self.op + '(' + ', '.join([str(p) for p in procParams]) + ');') 531 output.append('\n\t' + self.op + '(' + ', '.join([str(p) for p in procParams]) + ');')
512 prog.lastOp = self 532 prog.lastOp = self
556 def localSize(self, name): 576 def localSize(self, name):
557 if name in self.current_locals: 577 if name in self.current_locals:
558 return self.current_locals[name] 578 return self.current_locals[name]
559 return self.parent.localSize(name) 579 return self.parent.localSize(name)
560 580
561 def generate(self, prog, parent, fieldVals, output, otype): 581 def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
562 prog.pushScope(self) 582 prog.pushScope(self)
563 param = prog.resolveParam(self.param, parent, fieldVals) 583 param = prog.resolveParam(self.param, parent, fieldVals)
564 if type(param) is int: 584 if type(param) is int:
565 self.regValues = self.parent.regValues 585 self.regValues = self.parent.regValues
566 if param in self.cases: 586 if param in self.cases:
567 self.current_locals = self.case_locals[param] 587 self.current_locals = self.case_locals[param]
568 output.append('\n\t{') 588 output.append('\n\t{')
569 for local in self.case_locals[param]: 589 for local in self.case_locals[param]:
570 output.append('\n\tuint{0}_t {1};'.format(self.case_locals[param][local], local)) 590 output.append('\n\tuint{0}_t {1};'.format(self.case_locals[param][local], local))
571 for op in self.cases[param]: 591 self.processOps(prog, fieldVals, output, otype, self.cases[param])
572 op.generate(prog, self, fieldVals, output, otype)
573 output.append('\n\t}') 592 output.append('\n\t}')
574 elif self.default: 593 elif self.default:
575 self.current_locals = self.default_locals 594 self.current_locals = self.default_locals
576 output.append('\n\t{') 595 output.append('\n\t{')
577 for local in self.default_locals: 596 for local in self.default_locals:
578 output.append('\n\tuint{0}_t {1};'.format(self.default[local], local)) 597 output.append('\n\tuint{0}_t {1};'.format(self.default[local], local))
579 for op in self.default: 598 self.processOps(prog, fieldVals, output, otype, self.default)
580 op.generate(prog, self, fieldVals, output, otype)
581 output.append('\n\t}') 599 output.append('\n\t}')
582 else: 600 else:
583 output.append('\n\tswitch(' + param + ')') 601 output.append('\n\tswitch(' + param + ')')
584 output.append('\n\t{') 602 output.append('\n\t{')
585 for case in self.cases: 603 for case in self.cases:
586 self.current_locals = self.case_locals[case] 604 self.current_locals = self.case_locals[case]
587 self.regValues = dict(self.parent.regValues) 605 self.regValues = dict(self.parent.regValues)
588 output.append('\n\tcase {0}U: '.format(case) + '{') 606 output.append('\n\tcase {0}U: '.format(case) + '{')
589 for local in self.case_locals[case]: 607 for local in self.case_locals[case]:
590 output.append('\n\tuint{0}_t {1};'.format(self.case_locals[case][local], local)) 608 output.append('\n\tuint{0}_t {1};'.format(self.case_locals[case][local], local))
591 for op in self.cases[case]: 609 self.processOps(prog, fieldVals, output, otype, self.cases[case])
592 op.generate(prog, self, fieldVals, output, otype)
593 output.append('\n\tbreak;') 610 output.append('\n\tbreak;')
594 output.append('\n\t}') 611 output.append('\n\t}')
595 if self.default: 612 if self.default:
596 self.current_locals = self.default_locals 613 self.current_locals = self.default_locals
597 self.regValues = dict(self.parent.regValues) 614 self.regValues = dict(self.parent.regValues)
598 output.append('\n\tdefault: {') 615 output.append('\n\tdefault: {')
599 for local in self.default_locals: 616 for local in self.default_locals:
600 output.append('\n\tuint{0}_t {1};'.format(self.default_locals[local], local)) 617 output.append('\n\tuint{0}_t {1};'.format(self.default_locals[local], local))
601 for op in self.default: 618 self.processOps(prog, fieldVals, output, otype, self.default)
602 op.generate(prog, self, fieldVals, output, otype)
603 output.append('\n\t}') 619 output.append('\n\t}')
604 prog.popScope() 620 prog.popScope()
605 621
606 def __str__(self): 622 def __str__(self):
607 keys = self.cases.keys() 623 keys = self.cases.keys()
664 680
665 def _genTrueBody(self, prog, fieldVals, output, otype): 681 def _genTrueBody(self, prog, fieldVals, output, otype):
666 self.curLocals = self.locals 682 self.curLocals = self.locals
667 for local in self.locals: 683 for local in self.locals:
668 output.append('\n\tuint{sz}_t {nm};'.format(sz=self.locals[local], nm=local)) 684 output.append('\n\tuint{sz}_t {nm};'.format(sz=self.locals[local], nm=local))
669 for op in self.body: 685 self.processOps(prog, fieldVals, output, otype, self.body)
670 op.generate(prog, self, fieldVals, output, otype)
671 686
672 def _genFalseBody(self, prog, fieldVals, output, otype): 687 def _genFalseBody(self, prog, fieldVals, output, otype):
673 self.curLocals = self.elseLocals 688 self.curLocals = self.elseLocals
674 for local in self.elseLocals: 689 for local in self.elseLocals:
675 output.append('\n\tuint{sz}_t {nm};'.format(sz=self.elseLocals[local], nm=local)) 690 output.append('\n\tuint{sz}_t {nm};'.format(sz=self.elseLocals[local], nm=local))
676 for op in self.elseBody: 691 self.processOps(prog, fieldVals, output, otype, self.elsebody)
677 op.generate(prog, self, fieldVals, output, otype)
678 692
679 def _genConstParam(self, param, prog, fieldVals, output, otype): 693 def _genConstParam(self, param, prog, fieldVals, output, otype):
680 if param: 694 if param:
681 self._genTrueBody(prog, fieldVals, output, otype) 695 self._genTrueBody(prog, fieldVals, output, otype)
682 else: 696 else:
683 self._genFalseBody(prog, fieldVals, output, otype) 697 self._genFalseBody(prog, fieldVals, output, otype)
684 698
685 def generate(self, prog, parent, fieldVals, output, otype): 699 def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
686 self.regValues = parent.regValues 700 self.regValues = parent.regValues
687 try: 701 try:
688 self._genConstParam(prog.checkBool(self.cond), prog, fieldVals, output, otype) 702 self._genConstParam(prog.checkBool(self.cond), prog, fieldVals, output, otype)
689 except Exception: 703 except Exception:
690 if self.cond in _ifCmpImpl[otype]: 704 if self.cond in _ifCmpImpl[otype]:
826 loc,_,bit = self.flagStorage[flag].partition('.') 840 loc,_,bit = self.flagStorage[flag].partition('.')
827 if bit: 841 if bit:
828 return (loc, int(bit)) 842 return (loc, int(bit))
829 else: 843 else:
830 return loc 844 return loc
845
846 def parseFlagUpdate(self, flagString):
847 last = ''
848 autoUpdate = set()
849 explicit = {}
850 for c in flagString:
851 if c.isdigit():
852 if last.isalpha():
853 num = int(c)
854 if num > 1:
855 raise Exception(c + ' is not a valid digit for update_flags')
856 explicit[last] = num
857 last = c
858 else:
859 raise Exception('Digit must follow flag letter in update_flags')
860 else:
861 if last.isalpha():
862 autoUpdate.add(last)
863 last = c
864 if last.isalpha():
865 autoUpdate.add(last)
866 return (autoUpdate, explicit)
831 867
832 def disperseFlags(self, prog, otype): 868 def disperseFlags(self, prog, otype):
833 bitToFlag = [None] * (self.maxBit+1) 869 bitToFlag = [None] * (self.maxBit+1)
834 src = prog.resolveReg(self.flagReg, None, {}) 870 src = prog.resolveReg(self.flagReg, None, {})
835 output = [] 871 output = []
947 self.flags = flags 983 self.flags = flags
948 self.lastDst = None 984 self.lastDst = None
949 self.scopes = [] 985 self.scopes = []
950 self.currentScope = None 986 self.currentScope = None
951 self.lastOp = None 987 self.lastOp = None
988 self.carryFlowDst = None
989 self.lastA = None
990 self.lastB = None
952 991
953 def __str__(self): 992 def __str__(self):
954 pieces = [] 993 pieces = []
955 for reg in self.regs: 994 for reg in self.regs:
956 pieces.append(str(self.regs[reg])) 995 pieces.append(str(self.regs[reg]))