fix(core/protobuf): fix panic when a very large protobuf varint is received

pull/2871/head
matejcik 1 year ago
parent 79632d32c7
commit f4bb70cecf

@ -315,14 +315,26 @@ impl<'a> InputStream<'a> {
pub fn read_uvarint(&mut self) -> Result<u64, Error> { pub fn read_uvarint(&mut self) -> Result<u64, Error> {
let mut uint = 0; let mut uint = 0;
let mut shift = 0; let mut shift = 0;
loop { let mut last_byte = true;
while shift <= 64 - 7 {
// Shifting by 64 - 7 and then adding 7 bits is always safe.
let byte = self.read_byte()?; let byte = self.read_byte()?;
uint += (byte as u64 & 0x7F) << shift; uint += (byte as u64 & 0x7F) << shift;
shift += 7; shift += 7;
if byte & 0x80 == 0 { if byte & 0x80 == 0 {
last_byte = false;
break; break;
} }
} }
if last_byte {
// After reading 9 bytes, there is only one bit remaining to be set.
let byte = self.read_byte()?;
if byte > 1 {
return Err(Error::OutOfRange);
} else {
uint += (byte as u64) << shift;
}
}
Ok(uint) Ok(uint)
} }
} }

@ -1,7 +1,12 @@
from common import * from common import *
from trezor import protobuf from trezor import protobuf
from trezor.messages import WebAuthnCredential, Failure, SignMessage, DebugLinkMemoryRead from trezor.messages import (
WebAuthnCredential,
Failure,
SignMessage,
DebugLinkMemoryRead,
)
def load_uvarint32(data: bytes) -> int: def load_uvarint32(data: bytes) -> int:
@ -49,7 +54,9 @@ def dump_message(msg: protobuf.MessageType) -> bytearray:
return buffer return buffer
def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType: def load_message(
msg_type: Type[protobuf.MessageType], buffer: bytes
) -> protobuf.MessageType:
return protobuf.decode(buffer, msg_type, False) return protobuf.decode(buffer, msg_type, False)
@ -70,6 +77,14 @@ class TestProtobuf(unittest.TestCase):
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF) self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456) self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
self.assertEqual(
load_uvarint64(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
0xFFFF_FFFF_FFFF_FFFF,
)
with self.assertRaises(OverflowError):
i = load_uvarint64(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02")
print(hex(i))
def test_validate_enum(self): def test_validate_enum(self):
# ok message: # ok message:
msg = Failure(code=7) msg = Failure(code=7)
@ -110,6 +125,5 @@ class TestProtobuf(unittest.TestCase):
self.assertEqual(nmsg.coin_name, "Bitcoin") self.assertEqual(nmsg.coin_name, "Bitcoin")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

Loading…
Cancel
Save