changeset 1704:89932fd29abd

First stab at carry and half-carry calculation in CPU DSL
author Michael Pavone <pavone@retrodev.com>
date Mon, 28 Jan 2019 20:54:55 -0800
parents 49a52c737bf0
children 9ab64ef5cba0
files cpu_dsl.py
diffstat 1 files changed, 96 insertions(+), 57 deletions(-) [+]
line wrap: on
line diff
--- a/cpu_dsl.py	Mon Jan 28 19:24:04 2019 -0800
+++ b/cpu_dsl.py	Mon Jan 28 20:54:55 2019 -0800
@@ -20,6 +20,14 @@
 			self.addOp(NormalOp(parts))
 		return self
 		
+	def processOps(self, prog, fieldVals, output, otype, oplist):
+		for i in range(0, len(oplist)):
+			if i + 1 < len(oplist) and oplist[i+1].op == 'update_flags':
+				flagUpdates, _ = prog.flags.parseFlagUpdate(oplist[i+1].params[0])
+			else:
+				flagUpdates = None
+			oplist[i].generate(prog, self, fieldVals, output, otype, flagUpdates)
+		
 	def resolveLocal(self, name):
 		return None
 			
@@ -121,8 +129,7 @@
 			output.append('\n\tuint{sz}_t {name};'.format(sz=self.locals[var], name=var))
 		self.newLocals = []
 		fieldVals,_ = self.getFieldVals(value)
-		for op in self.implementation:
-			op.generate(prog, self, fieldVals, output, otype)
+		self.processOps(prog, fieldVals, output, otype, self.implementation)
 		begin = '\nvoid ' + self.generateName(value) + '(' + prog.context_type + ' *context)\n{'
 		if prog.needFlagCoalesce:
 			begin += prog.flags.coalesceFlags(prog, otype)
@@ -189,8 +196,7 @@
 		for name in self.locals:
 			size = self.locals[name]
 			output.append('\n\tuint{size}_t {sub}_{local};'.format(size=size, sub=self.name, local=name))
-		for op in self.implementation:
-			op.generate(prog, self, argValues, output, otype)
+		self.processOps(prog, argValues, output, otype, self.implementation)
 		prog.popScope()
 		
 	def __str__(self):
@@ -209,15 +215,36 @@
 		self.impls = {}
 		self.outOp = ()
 	def cBinaryOperator(self, op):
-		def _impl(prog, params):
+		def _impl(prog, params, rawParams, flagUpdates):
 			if op == '-':
 				a = params[1]
 				b = params[0]
 			else:
 				a = params[0]
 				b = params[1]
-			return '\n\t{dst} = {a} {op} {b};'.format(
-				dst = params[2], a = a, b = b, op = op
+			needsCarry = needsOflow = needsHalf = False
+			if flagUpdates:
+				for flag in flagUpdates:
+					calc = prog.flags.flagCalc[flag]
+					if calc == 'carry':
+						needsCarry = True
+					elif calc == 'half-carry':
+						needsHalf = True
+					elif calc == 'overflow':
+						needsOflow = True
+			decl = ''
+			if needsCarry or needsOflow or needsHalf:
+				size = prog.paramSize(rawParams[2])
+				if needsCarry:
+					size *= 2
+				decl,name = prog.getTemp(size)
+				dst = prog.carryFlowDst = name
+				prog.lastA = a
+				prog.lastB = b
+			else:
+				dst = params[2]
+			return decl + '\n\t{dst} = {a} {op} {b};'.format(
+				dst = dst, a = a, b = b, op = op
 			)
 		self.impls['c'] = _impl
 		self.outOp = (2,)
@@ -244,11 +271,13 @@
 		return not self.evalFun is None
 	def numArgs(self):
 		return self.evalFun.__code__.co_argcount
-	def generate(self, otype, prog, params, rawParams):
+	def generate(self, otype, prog, params, rawParams, flagUpdates):
 		if self.impls[otype].__code__.co_argcount == 2:
 			return self.impls[otype](prog, params)
+		elif self.impls[otype].__code__.co_argcount == 3:
+			return self.impls[otype](prog, params, rawParams)
 		else:
-			return self.impls[otype](prog, params, rawParams)
+			return self.impls[otype](prog, params, rawParams, flagUpdates)
 		
 		
 def _xchgCImpl(prog, params, rawParams):
@@ -264,36 +293,26 @@
 	return '\n\timpl_{tbl}[{op}](context);'.format(tbl = table, op = params[0])
 
 def _updateFlagsCImpl(prog, params, rawParams):
-	i = 0
-	last = ''
-	autoUpdate = set()
-	explicit = {}
-	for c in params[0]:
-		if c.isdigit():
-			if last.isalpha():
-				num = int(c)
-				if num > 1:
-					raise Exception(c + ' is not a valid digit for update_flags')
-				explicit[last] = num
-				last = c
-			else:
-				raise Exception('Digit must follow flag letter in update_flags')
-		else:
-			if last.isalpha():
-				autoUpdate.add(last)
-			last = c
-	if last.isalpha():
-		autoUpdate.add(last)
+	autoUpdate, explicit = prog.flags.parseFlagUpdate(params[0])
 	output = []
 	#TODO: handle autoUpdate flags
 	for flag in autoUpdate:
 		calc = prog.flags.flagCalc[flag]
 		calc,_,resultBit = calc.partition('-')
-		lastDst = prog.resolveParam(prog.lastDst, None, {})
+		if prog.carryFlowDst:
+			lastDst = prog.carryFlowDst
+		else:
+			lastDst = prog.resolveParam(prog.lastDst, None, {})
 		storage = prog.flags.getStorage(flag)
-		if calc == 'bit' or calc == 'sign':
+		if calc == 'bit' or calc == 'sign' or calc == 'carry' or calc == 'half':
+			myRes = lastDst
 			if calc == 'sign':
 				resultBit = prog.paramSize(prog.lastDst) - 1
+			elif calc == 'carry':
+				resultBit = prog.paramSize(prog.lastDst)
+			elif calc == 'half':
+				resultBit = 4
+				myRes = '({a} ^ {b} ^ {res})'.format(a = prog.lastA, b = prog.lastB, res = lastDst)
 			else:
 				resultBit = int(resultBit)
 			if type(storage) is tuple:
@@ -302,7 +321,7 @@
 				if storageBit == resultBit:
 					#TODO: optimize this case
 					output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} & {mask}U);'.format(
-						reg = reg, mask = 1 << resultBit, res = lastDst
+						reg = reg, mask = 1 << resultBit, res = myRes
 					))
 				else:
 					if resultBit > storageBit:
@@ -312,11 +331,11 @@
 						op = '<<'
 						shift = storageBit - resultBit
 					output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} {op} {shift}U & {mask}U);'.format(
-						reg = reg, mask = 1 << storageBit, res = lastDst, op = op, shift = shift
+						reg = reg, mask = 1 << storageBit, res = myRes, op = op, shift = shift
 					))
 			else:
 				reg = prog.resolveParam(storage, None, {})
-				output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=lastDst, mask = 1 << resultBit))
+				output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=myRes, mask = 1 << resultBit))
 		elif calc == 'zero':
 			if type(storage) is tuple:
 				reg,storageBit = storage
@@ -328,15 +347,16 @@
 				reg = prog.resolveParam(storage, None, {})
 				output.append('\n\t{reg} = {res} == 0;'.format(
 					reg = reg, res = lastDst
-				))
-		elif calc == 'half-carry':
-			pass
-		elif calc == 'carry':
-			pass
+				))			
 		elif calc == 'overflow':
 			pass
 		elif calc == 'parity':
 			pass
+		else:
+			raise Exception('Unknown flag calc type: ' + calc)
+	if prog.carryFlowDst:
+		output.append('\n\t{dst} = {tmpdst};'.format(dst = prog.resolveParam(prog.lastDst, None, {}), tmpdst = prog.carryFlowDst))
+		prog.carryFlowDst = None
 	#TODO: combine explicit flags targeting the same storage location
 	for flag in explicit:
 		location = prog.flags.getStorage(flag)
@@ -458,9 +478,9 @@
 		self.op = parts[0]
 		self.params = parts[1:]
 		
-	def generate(self, prog, parent, fieldVals, output, otype):
+	def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
 		procParams = []
-		allParamsConst = True
+		allParamsConst = flagUpdates is None
 		opDef = _opMap.get(self.op)
 		for param in self.params:
 			allowConst = (self.op in prog.subroutines or len(procParams) != len(self.params) - 1) and param in parent.regValues
@@ -502,9 +522,9 @@
 				if prog.isReg(dst):
 					shortProc = (procParams[0], procParams[-1])
 					shortParams = (self.params[0], self.params[-1])
-					output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams))
+					output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams, None))
 			else:
-				output.append(opDef.generate(otype, prog, procParams, self.params))
+				output.append(opDef.generate(otype, prog, procParams, self.params, flagUpdates))
 		elif self.op in prog.subroutines:
 			prog.subroutines[self.op].inline(prog, procParams, output, otype, parent)
 		else:
@@ -558,7 +578,7 @@
 			return self.current_locals[name]
 		return self.parent.localSize(name)
 			
-	def generate(self, prog, parent, fieldVals, output, otype):
+	def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
 		prog.pushScope(self)
 		param = prog.resolveParam(self.param, parent, fieldVals)
 		if type(param) is int:
@@ -568,16 +588,14 @@
 				output.append('\n\t{')
 				for local in self.case_locals[param]:
 					output.append('\n\tuint{0}_t {1};'.format(self.case_locals[param][local], local))
-				for op in self.cases[param]:
-					op.generate(prog, self, fieldVals, output, otype)
+				self.processOps(prog, fieldVals, output, otype, self.cases[param])
 				output.append('\n\t}')
 			elif self.default:
 				self.current_locals = self.default_locals
 				output.append('\n\t{')
 				for local in self.default_locals:
 					output.append('\n\tuint{0}_t {1};'.format(self.default[local], local))
-				for op in self.default:
-					op.generate(prog, self, fieldVals, output, otype)
+				self.processOps(prog, fieldVals, output, otype, self.default)
 				output.append('\n\t}')
 		else:
 			output.append('\n\tswitch(' + param + ')')
@@ -588,8 +606,7 @@
 				output.append('\n\tcase {0}U: '.format(case) + '{')
 				for local in self.case_locals[case]:
 					output.append('\n\tuint{0}_t {1};'.format(self.case_locals[case][local], local))
-				for op in self.cases[case]:
-					op.generate(prog, self, fieldVals, output, otype)
+				self.processOps(prog, fieldVals, output, otype, self.cases[case])
 				output.append('\n\tbreak;')
 				output.append('\n\t}')
 			if self.default:
@@ -598,8 +615,7 @@
 				output.append('\n\tdefault: {')
 				for local in self.default_locals:
 					output.append('\n\tuint{0}_t {1};'.format(self.default_locals[local], local))
-				for op in self.default:
-					op.generate(prog, self, fieldVals, output, otype)
+				self.processOps(prog, fieldVals, output, otype, self.default)
 			output.append('\n\t}')
 		prog.popScope()
 	
@@ -666,15 +682,13 @@
 		self.curLocals = self.locals
 		for local in self.locals:
 			output.append('\n\tuint{sz}_t {nm};'.format(sz=self.locals[local], nm=local))
-		for op in self.body:
-			op.generate(prog, self, fieldVals, output, otype)
+		self.processOps(prog, fieldVals, output, otype, self.body)
 			
 	def _genFalseBody(self, prog, fieldVals, output, otype):
 		self.curLocals = self.elseLocals
 		for local in self.elseLocals:
 			output.append('\n\tuint{sz}_t {nm};'.format(sz=self.elseLocals[local], nm=local))
-		for op in self.elseBody:
-			op.generate(prog, self, fieldVals, output, otype)
+		self.processOps(prog, fieldVals, output, otype, self.elsebody)
 	
 	def _genConstParam(self, param, prog, fieldVals, output, otype):
 		if param:
@@ -682,7 +696,7 @@
 		else:
 			self._genFalseBody(prog, fieldVals, output, otype)
 			
-	def generate(self, prog, parent, fieldVals, output, otype):
+	def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
 		self.regValues = parent.regValues
 		try:
 			self._genConstParam(prog.checkBool(self.cond), prog, fieldVals, output, otype)
@@ -829,6 +843,28 @@
 		else:
 			return loc 
 	
+	def parseFlagUpdate(self, flagString):
+		last = ''
+		autoUpdate = set()
+		explicit = {}
+		for c in flagString:
+			if c.isdigit():
+				if last.isalpha():
+					num = int(c)
+					if num > 1:
+						raise Exception(c + ' is not a valid digit for update_flags')
+					explicit[last] = num
+					last = c
+				else:
+					raise Exception('Digit must follow flag letter in update_flags')
+			else:
+				if last.isalpha():
+					autoUpdate.add(last)
+				last = c
+		if last.isalpha():
+			autoUpdate.add(last)
+		return (autoUpdate, explicit)
+	
 	def disperseFlags(self, prog, otype):
 		bitToFlag = [None] * (self.maxBit+1)
 		src = prog.resolveReg(self.flagReg, None, {})
@@ -949,6 +985,9 @@
 		self.scopes = []
 		self.currentScope = None
 		self.lastOp = None
+		self.carryFlowDst = None
+		self.lastA = None
+		self.lastB = None
 		
 	def __str__(self):
 		pieces = []