diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoReader.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoReader.swift index a1935b11ad..0951b98aff 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoReader.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoReader.swift @@ -183,8 +183,7 @@ public final class ProtoReader { throw ProtoDecoder.Error.unexpectedEndGroupFieldNumber(expected: nil, found: tag) case .lengthDelimited: - let length = try Int32(truncatingIfNeeded: buffer.readVarint()) - state = .lengthDelimited(length: Int(length)) + state = .lengthDelimited(length: try readLengthDelimitedLength(tag: tag)) currentTag = tag return tag @@ -793,8 +792,7 @@ public final class ProtoReader { return case .lengthDelimited: - let length = try Int32(truncatingIfNeeded: buffer.readVarint()) - state = .lengthDelimited(length: Int(length)) + state = .lengthDelimited(length: try readLengthDelimitedLength(tag: tag)) let data = try readData() try unknownFieldsWriter.encode(tag: tag, value: data) @@ -856,6 +854,16 @@ public final class ProtoReader { return (tag, wireType) } + private func readLengthDelimitedLength(tag: UInt32) throws -> Int { + let length = try Int32(truncatingIfNeeded: buffer.readVarint()) + if length < 0 { + throw ProtoDecoder.Error.invalidStructure( + message: "Negative length: \(length). Reader position: \(buffer.position). Last read tag: \(tag)." + ) + } + return Int(length) + } + // MARK: - Private Methods - Decoding - Repeated Field private func decode(into array: inout [T], decode: () throws -> T?) throws { diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ReadBuffer.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ReadBuffer.swift index 44f987db12..605578e6e7 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ReadBuffer.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ReadBuffer.swift @@ -67,7 +67,7 @@ final class ReadBuffer { } func verifyAdditional(count: Int) throws { - guard pointer.advanced(by: count) <= end else { + guard count >= 0, pointer.advanced(by: count) <= end else { throw ProtoDecoder.Error.unexpectedEndOfData } } diff --git a/wire-runtime-swift/src/test/swift/ProtoReaderTests.swift b/wire-runtime-swift/src/test/swift/ProtoReaderTests.swift index 1b27b8aa5a..7093759b45 100644 --- a/wire-runtime-swift/src/test/swift/ProtoReaderTests.swift +++ b/wire-runtime-swift/src/test/swift/ProtoReaderTests.swift @@ -1111,6 +1111,42 @@ final class ProtoReaderTests: XCTestCase { } } + func testLengthDelimitedRejectsNegativeLength() throws { + let data = Foundation.Data(hexEncoded: """ + 0A // (Tag 1 | Length Delimited) + 80FFFFFF0F // Length -128 + """)! + + XCTAssertThrowsError( + try test(data: data) { reader in + _ = try reader.forEachTag { _ in + XCTFail("The negative length should have thrown before returning a tag") + } + } + ) { error in + assertNegativeLengthError(error, readerPosition: 6) + } + } + + func testSkipGroupRejectsNegativeLengthDelimited() throws { + let data = Foundation.Data(hexEncoded: """ + 9B06 // (Tag 99 | Start Group) + 0A // (Tag 1 | Length Delimited) + 80FFFFFF0F // Length -128 + 9C06 // (Tag 99 | End Group) + """)! + + XCTAssertThrowsError( + try test(data: data) { reader in + _ = try reader.forEachTag { _ in + XCTFail("The group should have been skipped or rejected") + } + } + ) { error in + assertNegativeLengthError(error, readerPosition: 8) + } + } + // MARK: - Tests - Unknown Fields func testUnknownFields() throws { @@ -1330,6 +1366,14 @@ final class ProtoReaderTests: XCTestCase { try test(reader) } } + + private func assertNegativeLengthError(_ error: Error, readerPosition: Int) { + guard case let ProtoDecoder.Error.invalidStructure(message) = error else { + XCTFail("Unexpected error: \(error)") + return + } + XCTAssertEqual(message, "Negative length: -128. Reader position: \(readerPosition). Last read tag: 1.") + } } // MARK: -