From f4bb70cecf573e1c0a78e864c4fca1f650324baf Mon Sep 17 00:00:00 2001 From: matejcik Date: Thu, 9 Mar 2023 15:12:48 +0100 Subject: [PATCH] fix(core/protobuf): fix panic when a very large protobuf varint is received --- core/embed/rust/src/protobuf/decode.rs | 14 +++++++++++++- core/tests/test_trezor.protobuf.py | 20 +++++++++++++++++--- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/core/embed/rust/src/protobuf/decode.rs b/core/embed/rust/src/protobuf/decode.rs index ea62fee7af..25f1b3c8a6 100644 --- a/core/embed/rust/src/protobuf/decode.rs +++ b/core/embed/rust/src/protobuf/decode.rs @@ -315,14 +315,26 @@ impl<'a> InputStream<'a> { pub fn read_uvarint(&mut self) -> Result { let mut uint = 0; let mut shift = 0; - loop { + let mut last_byte = true; + while shift <= 64 - 7 { + // Shifting by 64 - 7 and then adding 7 bits is always safe. let byte = self.read_byte()?; uint += (byte as u64 & 0x7F) << shift; shift += 7; if byte & 0x80 == 0 { + last_byte = false; break; } } + if last_byte { + // After reading 9 bytes, there is only one bit remaining to be set. + let byte = self.read_byte()?; + if byte > 1 { + return Err(Error::OutOfRange); + } else { + uint += (byte as u64) << shift; + } + } Ok(uint) } } diff --git a/core/tests/test_trezor.protobuf.py b/core/tests/test_trezor.protobuf.py index 7f0c632886..0656ff17e9 100644 --- a/core/tests/test_trezor.protobuf.py +++ b/core/tests/test_trezor.protobuf.py @@ -1,7 +1,12 @@ from common import * from trezor import protobuf -from trezor.messages import WebAuthnCredential, Failure, SignMessage, DebugLinkMemoryRead +from trezor.messages import ( + WebAuthnCredential, + Failure, + SignMessage, + DebugLinkMemoryRead, +) def load_uvarint32(data: bytes) -> int: @@ -49,7 +54,9 @@ def dump_message(msg: protobuf.MessageType) -> bytearray: return buffer -def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType: +def load_message( + msg_type: Type[protobuf.MessageType], buffer: bytes +) -> protobuf.MessageType: return protobuf.decode(buffer, msg_type, False) @@ -70,6 +77,14 @@ class TestProtobuf(unittest.TestCase): self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF) self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456) + self.assertEqual( + load_uvarint64(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"), + 0xFFFF_FFFF_FFFF_FFFF, + ) + with self.assertRaises(OverflowError): + i = load_uvarint64(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02") + print(hex(i)) + def test_validate_enum(self): # ok message: msg = Failure(code=7) @@ -110,6 +125,5 @@ class TestProtobuf(unittest.TestCase): self.assertEqual(nmsg.coin_name, "Bitcoin") - if __name__ == "__main__": unittest.main()