mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-14 10:38:09 +00:00
fix(core/protobuf): fix panic when a very large protobuf varint is received
This commit is contained in:
parent
79632d32c7
commit
f4bb70cecf
@ -315,14 +315,26 @@ impl<'a> InputStream<'a> {
|
|||||||
pub fn read_uvarint(&mut self) -> Result<u64, Error> {
|
pub fn read_uvarint(&mut self) -> Result<u64, Error> {
|
||||||
let mut uint = 0;
|
let mut uint = 0;
|
||||||
let mut shift = 0;
|
let mut shift = 0;
|
||||||
loop {
|
let mut last_byte = true;
|
||||||
|
while shift <= 64 - 7 {
|
||||||
|
// Shifting by 64 - 7 and then adding 7 bits is always safe.
|
||||||
let byte = self.read_byte()?;
|
let byte = self.read_byte()?;
|
||||||
uint += (byte as u64 & 0x7F) << shift;
|
uint += (byte as u64 & 0x7F) << shift;
|
||||||
shift += 7;
|
shift += 7;
|
||||||
if byte & 0x80 == 0 {
|
if byte & 0x80 == 0 {
|
||||||
|
last_byte = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if last_byte {
|
||||||
|
// After reading 9 bytes, there is only one bit remaining to be set.
|
||||||
|
let byte = self.read_byte()?;
|
||||||
|
if byte > 1 {
|
||||||
|
return Err(Error::OutOfRange);
|
||||||
|
} else {
|
||||||
|
uint += (byte as u64) << shift;
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(uint)
|
Ok(uint)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
from common import *
|
from common import *
|
||||||
|
|
||||||
from trezor import protobuf
|
from trezor import protobuf
|
||||||
from trezor.messages import WebAuthnCredential, Failure, SignMessage, DebugLinkMemoryRead
|
from trezor.messages import (
|
||||||
|
WebAuthnCredential,
|
||||||
|
Failure,
|
||||||
|
SignMessage,
|
||||||
|
DebugLinkMemoryRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_uvarint32(data: bytes) -> int:
|
def load_uvarint32(data: bytes) -> int:
|
||||||
@ -49,7 +54,9 @@ def dump_message(msg: protobuf.MessageType) -> bytearray:
|
|||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType:
|
def load_message(
|
||||||
|
msg_type: Type[protobuf.MessageType], buffer: bytes
|
||||||
|
) -> protobuf.MessageType:
|
||||||
return protobuf.decode(buffer, msg_type, False)
|
return protobuf.decode(buffer, msg_type, False)
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +77,14 @@ class TestProtobuf(unittest.TestCase):
|
|||||||
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
|
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
|
||||||
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
|
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
load_uvarint64(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
|
||||||
|
0xFFFF_FFFF_FFFF_FFFF,
|
||||||
|
)
|
||||||
|
with self.assertRaises(OverflowError):
|
||||||
|
i = load_uvarint64(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02")
|
||||||
|
print(hex(i))
|
||||||
|
|
||||||
def test_validate_enum(self):
|
def test_validate_enum(self):
|
||||||
# ok message:
|
# ok message:
|
||||||
msg = Failure(code=7)
|
msg = Failure(code=7)
|
||||||
@ -110,6 +125,5 @@ class TestProtobuf(unittest.TestCase):
|
|||||||
self.assertEqual(nmsg.coin_name, "Bitcoin")
|
self.assertEqual(nmsg.coin_name, "Bitcoin")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user