diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 773bde90..3bae20b7 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -37,6 +37,7 @@ import { ConstructNotSupportedError, NoMethodMatchingSignatureError } from './error' +import { unannTypeToString } from '../types/ast/utils' import { FieldInfo, MethodInfos, SymbolInfo, SymbolTable, VariableInfo } from './symbol-table' type Label = { @@ -576,6 +577,181 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi return { stackSize: maxStack, resultType: resType } }, + TryStatement: (node: Node, cg: CodeGenerator) => { + let maxStack = 0 + const { block, catches } = node as any + const finallyNode: any = (node as any).finally + + const hasCatches = catches && catches.catchClauses && catches.catchClauses.length > 0 + + if (!hasCatches && !finallyNode) { + return { stackSize: compile(block, cg).stackSize, resultType: EMPTY_TYPE } + } + + if (hasCatches || finallyNode) { + maxStack = Math.max(maxStack, 1) + } + + const localExceptionTable: Array<{ + startPc: number + endPc: number + handlerLabel: Label + catchType: number + }> = [] + + // mark start of protected region + const tryStart = cg.generateNewLabel() + tryStart.offset = cg.code.length + + // compile try block + maxStack = Math.max(maxStack, compile(block, cg).stackSize) + + // end of protected region (first instruction after try block) + const tryEnd = cg.generateNewLabel() + tryEnd.offset = cg.code.length + + const catchAllLabel = finallyNode ? cg.generateNewLabel() : null + + // If finally exists, add catch-all entry for the try block + if (finallyNode && catchAllLabel) { + localExceptionTable.push({ + startPc: tryStart.offset, + endPc: tryEnd.offset, + handlerLabel: catchAllLabel, + catchType: 0 + }) + } + + // For normal path: run finally block if it exists + if (finallyNode) { + finallyNode.blockStatements.forEach((stmt: any) => { + const { stackSize } = compile(stmt, cg) + maxStack = Math.max(maxStack, stackSize) + }) + } + + // jump over handlers when try completes normally + const afterHandlers = cg.generateNewLabel() + cg.addBranchInstr(OPCODE.GOTO, afterHandlers) + + // For each catch clause, emit a handler and an exception table entry + if (hasCatches) { + for (const catchClause of catches.catchClauses) { + const handlerLabel = cg.generateNewLabel() + handlerLabel.offset = cg.code.length + + // determine catch type index (constant pool) + const catchTypeNode = catchClause.catchFormalParameter.catchType + const catchTypeName = unannTypeToString(catchTypeNode.unannClassType) + let catchClassName = 'java/lang/Throwable' + try { + catchClassName = cg.symbolTable.queryClass(catchTypeName).name + } catch (e) { + catchClassName = catchTypeName.includes('/') ? catchTypeName : catchTypeName.replace(/\./g, '/') + } + const catchTypeIndex = cg.constantPoolManager.indexClassInfo(catchClassName) + + // add exception table entry (startPc, endPc, handlerPc, catchType) + localExceptionTable.push({ + startPc: tryStart.offset, + endPc: tryEnd.offset, + handlerLabel: handlerLabel, + catchType: catchTypeIndex + }) + + // create scope for catch variable + cg.symbolTable.extend() + const varName = catchClause.catchFormalParameter.variableDeclaratorId + const varTypeStr = unannTypeToString(catchTypeNode.unannClassType) + const varInfo = { + name: varName, + accessFlags: 0, + index: cg.maxLocals, + typeName: varTypeStr, + typeDescriptor: cg.symbolTable.generateFieldDescriptor(varTypeStr) + } + cg.symbolTable.insertVariableInfo(varInfo) + if (['J', 'D'].includes(varInfo.typeDescriptor)) { + cg.maxLocals += 2 + } else { + cg.maxLocals++ + } + + // at handler entry, the exception object is on the stack; store it into the local + cg.code.push(OPCODE.ASTORE, varInfo.index) + + const catchStartOffset = cg.code.length + + // compile catch block statements + const catchBlock = catchClause.block + catchBlock.blockStatements.forEach((stmt: any) => { + const { stackSize } = compile(stmt, cg) + maxStack = Math.max(maxStack, stackSize) + }) + + const catchEndOffset = cg.code.length + + // teardown catch scope + cg.symbolTable.teardown() + + // If finally exists, add catch-all entry for this catch block + if (finallyNode && catchAllLabel && catchStartOffset < catchEndOffset) { + localExceptionTable.push({ + startPc: catchStartOffset, + endPc: catchEndOffset, + handlerLabel: catchAllLabel, + catchType: 0 + }) + } + + // For caught path: run finally block if it exists + if (finallyNode) { + finallyNode.blockStatements.forEach((stmt: any) => { + const { stackSize } = compile(stmt, cg) + maxStack = Math.max(maxStack, stackSize) + }) + } + + // after handler, jump to afterHandlers + cg.addBranchInstr(OPCODE.GOTO, afterHandlers) + } + } + + // If finally exists, add a catch-all handler that runs finally then rethrows + if (finallyNode && catchAllLabel) { + catchAllLabel.offset = cg.code.length + + // allocate temp local to store exception + const tempIndex = cg.maxLocals + cg.maxLocals += 1 + cg.code.push(OPCODE.ASTORE, tempIndex) + + // compile finally block inside catch-all + finallyNode.blockStatements.forEach((stmt: any) => { + const { stackSize } = compile(stmt, cg) + maxStack = Math.max(maxStack, stackSize) + }) + + // reload exception and rethrow + cg.code.push(OPCODE.ALOAD, tempIndex, OPCODE.ATHROW) + } + + // place after-handlers label + afterHandlers.offset = cg.code.length + + // Now that all labels are resolved, push to cg.exceptionTable + localExceptionTable.forEach(entry => { + cg.exceptionTable.push({ + startPc: entry.startPc, + endPc: entry.endPc, + handlerPc: entry.handlerLabel.offset, + catchType: entry.catchType + }) + }) + + return { stackSize: maxStack, resultType: EMPTY_TYPE } + }, + TernaryExpression: (node: Node, cg: CodeGenerator) => { let maxStack = 0 const { @@ -1723,6 +1899,7 @@ class CodeGenerator { constantPoolManager: ConstantPoolManager maxLocals: number = 0 stackSize: number = 0 + exceptionTable: Array = [] labels: Label[] = [] loopLabels: Label[][] = [] switchLabels: Label[] = [] @@ -1761,6 +1938,7 @@ class CodeGenerator { generateCode(currentClass: string, methodNode: MethodDeclaration) { this.symbolTable.extend() this.currentClass = currentClass + this.exceptionTable = [] if (!methodNode.methodModifier.includes('static')) { this.maxLocals++ } @@ -1799,7 +1977,6 @@ class CodeGenerator { } this.resolveLabels() - const exceptionTable: Array = [] const attributes: Array = [] const codeBuf = new Uint8Array(this.code).buffer const dataView = new DataView(codeBuf) @@ -1808,7 +1985,7 @@ class CodeGenerator { const attributeLength = 12 + this.code.length + - 8 * exceptionTable.length + + 8 * this.exceptionTable.length + attributes.map(attr => attr.attributeLength + 6).reduce((acc, val) => acc + val, 0) this.symbolTable.teardown() @@ -1819,8 +1996,8 @@ class CodeGenerator { maxLocals: this.maxLocals, codeLength: this.code.length, code: dataView, - exceptionTableLength: exceptionTable.length, - exceptionTable: exceptionTable, + exceptionTableLength: this.exceptionTable.length, + exceptionTable: this.exceptionTable, attributesCount: attributes.length, attributes: attributes } diff --git a/src/jvm/__tests__/thread.ts b/src/jvm/__tests__/thread.ts index e4ae911e..07395191 100644 --- a/src/jvm/__tests__/thread.ts +++ b/src/jvm/__tests__/thread.ts @@ -4,7 +4,9 @@ import { ReferenceClassData } from '../types/class/ClassData' import { JvmObject } from '../types/reference/Object' import Thread from '../../jvm/thread' import JVM from '../../jvm/jvm' +import { JavaStackFrame } from '../../jvm/stackframe' import { setupTest, TestThreadPool } from './__utils__/test-utils' +import { METHOD_FLAGS } from '../../ClassFile/types/methods' let thread: Thread let threadClass: ReferenceClassData @@ -67,4 +69,42 @@ describe('Thread', () => { test('should manage wide (64-bit) values on the operand stack correctly', () => { // TODO }) + + test('should route an exception to a matching try-catch handler in the current method', () => { + const setup = setupTest() + const { testLoader, thread: testThread, classes } = setup + const exceptionMethodClass = testLoader.createClass({ + className: 'TryCatchTest', + loader: testLoader, + methods: [ + { + accessFlags: [METHOD_FLAGS.ACC_PUBLIC], + name: 'test0', + descriptor: '()V', + attributes: [], + code: new DataView(new ArrayBuffer(1)), + exceptionTable: [ + { + startPc: 0, + endPc: 1, + handlerPc: 0, + catchType: 'java/lang/NullPointerException' + } + ] + } + ], + }) as ReferenceClassData + + const method = exceptionMethodClass.getMethod('test0()V') + expect(method).not.toBeNull() + + testThread.invokeStackFrame( + new JavaStackFrame(exceptionMethodClass, method as any, 0, []) + ) + const exceptionObj = classes.NullPointerException.instantiate() + testThread.throwException(exceptionObj) + + expect(testThread.getPC()).toBe(0) + expect(testThread.peekStackFrame().operandStack).toEqual([exceptionObj]) + }) }) diff --git a/src/jvm/exception-table.ts b/src/jvm/exception-table.ts index 15248a87..7bcf0d99 100644 --- a/src/jvm/exception-table.ts +++ b/src/jvm/exception-table.ts @@ -1,33 +1,45 @@ -import { ClassData } from "./types/class/ClassData" - -class Entry { - from: number - to: number - target: number - type: ClassData - - constructor(from: number, to: number, target: number, type: ClassData) { - this.from = from; - this.to = to; - this.target = target; - this.type = type; - } +import { ClassData } from './types/class/ClassData' + +export interface ExceptionTableEntry { + startPc: number + endPc: number + handlerPc: number + catchType: any | null } -export class ExceptionTable { - private entries: Entry[] +export class ExceptionTable implements Iterable { + private entries: ExceptionTableEntry[] + + constructor(entries?: ExceptionTableEntry[]) { + this.entries = entries ? entries.slice() : [] + } - retrieve(line: number): Entry | null { - this.entries.forEach(entry => { - if (line >= entry.from && line <= entry.to) { - return entry + retrieve(pc: number): ExceptionTableEntry | null { + for (let i = 0; i < this.entries.length; i++) { + const e = this.entries[i] + if (pc >= e.startPc && pc < e.endPc) { + return e } - }) + } return null } - insert(from: number, to: number, target: number, type: ClassData): void { - var entry = new Entry(from, to, target, type) - this.entries.push(entry) + insert(startPc: number, endPc: number, handlerPc: number, catchType: ClassData | null): void { + this.entries.push({ startPc, endPc, handlerPc, catchType }) + } + + toArray(): ExceptionTableEntry[] { + return this.entries.slice() + } + + [Symbol.iterator](): Iterator { + return this.entries[Symbol.iterator]() + } + forEach(cb: (entry: ExceptionTableEntry, idx?: number) => void) { + this.entries.forEach(cb) + } + + get length() { + return this.entries.length } } \ No newline at end of file diff --git a/src/jvm/types/class/Attributes.ts b/src/jvm/types/class/Attributes.ts index 92b5ee33..f3f2a051 100644 --- a/src/jvm/types/class/Attributes.ts +++ b/src/jvm/types/class/Attributes.ts @@ -15,6 +15,7 @@ import { SourceFileAttribute, StackMapFrame } from '../../../ClassFile/types/attributes' +import { ExceptionTable } from '../../exception-table' import { ConstantPool } from '../../constant-pool' import { ConstantClass, @@ -45,7 +46,8 @@ export const info2Attribute = (info: AttributeInfo, constantPool: ConstantPool): case 'Code': const code = info as CodeAttribute const attr: { [attributeName: string]: IAttribute } = {} - const exceptionTable = code.exceptionTable.map(handler => { + const exceptionTable = new ExceptionTable( + code.exceptionTable.map(handler => { return { startPc: handler.startPc, endPc: handler.endPc, @@ -54,6 +56,7 @@ export const info2Attribute = (info: AttributeInfo, constantPool: ConstantPool): handler.catchType === 0 ? null : (constantPool.get(handler.catchType) as ConstantClass) } }) + ) code.attributes.forEach(element => { attr[(constantPool.get(element.attributeNameIndex) as ConstantUtf8).get()] = info2Attribute( element, @@ -244,12 +247,7 @@ export interface Code extends IAttribute { codeLength: number code: DataView exceptionTableLength: number - exceptionTable: Array<{ - startPc: number - endPc: number - handlerPc: number - catchType: ConstantClass | null - }> + exceptionTable: ExceptionTable attributes: { [attributeName: string]: IAttribute } diff --git a/src/jvm/types/class/Method.ts b/src/jvm/types/class/Method.ts index b0adb49c..8e803643 100644 --- a/src/jvm/types/class/Method.ts +++ b/src/jvm/types/class/Method.ts @@ -6,6 +6,7 @@ import { attrInfo2Interface, parseMethodDescriptor, getArgs, logger } from '../. import { ErrorResult, ImmediateResult, ResultType, SuccessResult } from '../Result' import { JavaType, JvmObject } from '../reference/Object' import { Code, Exceptions, IAttribute, NestHost, Signature } from './Attributes' +import { ExceptionTable } from '../../exception-table' import { ReferenceClassData, ArrayClassData, ClassData } from './ClassData' import { ConstantClass, ConstantMethodref, ConstantNameAndType, ConstantUtf8 } from './Constants' @@ -484,7 +485,7 @@ export class Method { codeLength: dv.buffer.byteLength, code: dv, exceptionTableLength: 0, - exceptionTable: [], + exceptionTable: new ExceptionTable(), attributes: {} } as Code }, diff --git a/src/jvm/utils/disassembler/utils/readAttributes.ts b/src/jvm/utils/disassembler/utils/readAttributes.ts index a2c624fa..0f94bdea 100644 --- a/src/jvm/utils/disassembler/utils/readAttributes.ts +++ b/src/jvm/utils/disassembler/utils/readAttributes.ts @@ -186,7 +186,7 @@ function readCodeAttribute( throw new Error('Class format error: Code attribute invalid length') } - const code = new DataView(view.buffer, offset, codeLength) + const code = new DataView(view.buffer, view.byteOffset + offset, codeLength) offset += codeLength const exceptionTableLength = view.getUint16(offset)