src: run black

pull/25/head
Jan Pochyla 6 years ago
parent ead154b907
commit dcb15f77c3

@ -17,7 +17,14 @@ def addrtype_bytes(address_type: int):
if address_type <= 0xFFFFFF: if address_type <= 0xFFFFFF:
return bytes([(address_type >> 16), (address_type >> 8), (address_type & 0xFF)]) return bytes([(address_type >> 16), (address_type >> 8), (address_type & 0xFF)])
# else # else
return bytes([(address_type >> 24), (address_type >> 16), (address_type >> 8), (address_type & 0xFF)]) return bytes(
[
(address_type >> 24),
(address_type >> 16),
(address_type >> 8),
(address_type & 0xFF),
]
)
def check(address_type, raw_address): def check(address_type, raw_address):
@ -26,25 +33,36 @@ def check(address_type, raw_address):
if address_type <= 0xFFFF: if address_type <= 0xFFFF:
return address_type == (raw_address[0] << 8) | raw_address[1] return address_type == (raw_address[0] << 8) | raw_address[1]
if address_type <= 0xFFFFFF: if address_type <= 0xFFFFFF:
return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2] return (
address_type
== (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
)
# else # else
return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3] return (
address_type
== (raw_address[0] << 24)
| (raw_address[1] << 16)
| (raw_address[2] << 8)
| raw_address[3]
)
def strip(address_type, raw_address): def strip(address_type, raw_address):
if not check(address_type, raw_address): if not check(address_type, raw_address):
raise ValueError('Invalid address') raise ValueError("Invalid address")
l = length(address_type) l = length(address_type)
return raw_address[l:] return raw_address[l:]
def split(coin, raw_address): def split(coin, raw_address):
for f in ('address_type', for f in (
'address_type_p2sh', "address_type",
'address_type_p2wpkh', "address_type_p2sh",
'address_type_p2wsh'): "address_type_p2wpkh",
"address_type_p2wsh",
):
at = getattr(coin, f) at = getattr(coin, f)
if at is not None and check(at, raw_address): if at is not None and check(at, raw_address):
l = length(at) l = length(at)
return raw_address[:l], raw_address[l:] return raw_address[:l], raw_address[l:]
raise ValueError('Invalid address') raise ValueError("Invalid address")

@ -53,6 +53,6 @@ def set_passphrase(passphrase):
def clear(skip_passphrase: bool = False): def clear(skip_passphrase: bool = False):
set_seed(None) set_seed(None)
if skip_passphrase: if skip_passphrase:
set_passphrase('') set_passphrase("")
else: else:
set_passphrase(None) set_passphrase(None)

@ -1,7 +1,7 @@
from trezor.crypto.base58 import groestl512d_32, sha256d_32 from trezor.crypto.base58 import groestl512d_32, sha256d_32
class CoinInfo:
class CoinInfo:
def __init__( def __init__(
self, self,
coin_name: str, coin_name: str,
@ -37,7 +37,7 @@ class CoinInfo:
self.version_group_id = version_group_id self.version_group_id = version_group_id
self.bip115 = bip115 self.bip115 = bip115
self.curve_name = curve_name self.curve_name = curve_name
if curve_name == 'secp256k1-groestl': if curve_name == "secp256k1-groestl":
self.b58_hash = groestl512d_32 self.b58_hash = groestl512d_32
self.sign_hash_double = False self.sign_hash_double = False
else: else:

@ -19,11 +19,11 @@ def by_address_type(address_type):
for c in COINS: for c in COINS:
if c.address_type == address_type: if c.address_type == address_type:
return c return c
raise ValueError('Unknown coin address type %d' % address_type) raise ValueError("Unknown coin address type %d" % address_type)
def by_slip44(slip44): def by_slip44(slip44):
for c in COINS: for c in COINS:
if c.slip44 == slip44: if c.slip44 == slip44:
return c return c
raise ValueError('Unknown coin slip44 index %d' % slip44) raise ValueError("Unknown coin slip44 index %d" % slip44)

@ -21,7 +21,7 @@ async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
code = ButtonRequestType.Other code = ButtonRequestType.Other
await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck) await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck)
dialog = HoldToConfirmDialog(content, 'Hold to confirm', *args, **kwargs) dialog = HoldToConfirmDialog(content, "Hold to confirm", *args, **kwargs)
return await ctx.wait(dialog) == CONFIRMED return await ctx.wait(dialog) == CONFIRMED
@ -29,10 +29,10 @@ async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
async def require_confirm(*args, **kwargs): async def require_confirm(*args, **kwargs):
confirmed = await confirm(*args, **kwargs) confirmed = await confirm(*args, **kwargs)
if not confirmed: if not confirmed:
raise wire.ActionCancelled('Cancelled') raise wire.ActionCancelled("Cancelled")
async def require_hold_to_confirm(*args, **kwargs): async def require_hold_to_confirm(*args, **kwargs):
confirmed = await hold_to_confirm(*args, **kwargs) confirmed = await hold_to_confirm(*args, **kwargs)
if not confirmed: if not confirmed:
raise wire.ActionCancelled('Cancelled') raise wire.ActionCancelled("Cancelled")

@ -12,14 +12,11 @@ from apps.common.confirm import confirm
async def show_address(ctx, address: str): async def show_address(ctx, address: str):
lines = split_address(address) lines = split_address(address)
text = Text('Confirm address', ui.ICON_RECEIVE, icon_color=ui.GREEN) text = Text("Confirm address", ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.mono(*lines) text.mono(*lines)
return await confirm( return await confirm(
ctx, ctx, text, code=ButtonRequestType.Address, cancel="QR", cancel_style=ui.BTN_KEY
text, )
code=ButtonRequestType.Address,
cancel='QR',
cancel_style=ui.BTN_KEY)
async def show_qr(ctx, address: str): async def show_qr(ctx, address: str):
@ -28,14 +25,15 @@ async def show_qr(ctx, address: str):
qr_coef = const(4) qr_coef = const(4)
qr = Qr(address, (qr_x, qr_y), qr_coef) qr = Qr(address, (qr_x, qr_y), qr_coef)
text = Text('Confirm address', ui.ICON_RECEIVE, icon_color=ui.GREEN) text = Text("Confirm address", ui.ICON_RECEIVE, icon_color=ui.GREEN)
content = Container(qr, text) content = Container(qr, text)
return await confirm( return await confirm(
ctx, ctx,
content, content,
code=ButtonRequestType.Address, code=ButtonRequestType.Address,
cancel='Address', cancel="Address",
cancel_style=ui.BTN_KEY) cancel_style=ui.BTN_KEY,
)
def split_address(address: str): def split_address(address: str):

@ -13,16 +13,17 @@ from apps.common.cache import get_state
@ui.layout @ui.layout
async def request_passphrase_entry(ctx): async def request_passphrase_entry(ctx):
text = Text('Enter passphrase', ui.ICON_CONFIG) text = Text("Enter passphrase", ui.ICON_CONFIG)
text.normal('Where to enter your', 'passphrase?') text.normal("Where to enter your", "passphrase?")
text.render() text.render()
ack = await ctx.call( ack = await ctx.call(
ButtonRequest(code=ButtonRequestType.PassphraseType), ButtonRequest(code=ButtonRequestType.PassphraseType),
MessageType.ButtonAck, MessageType.ButtonAck,
MessageType.Cancel) MessageType.Cancel,
)
if ack.MESSAGE_WIRE_TYPE == MessageType.Cancel: if ack.MESSAGE_WIRE_TYPE == MessageType.Cancel:
raise wire.ActionCancelled('Passphrase cancelled') raise wire.ActionCancelled("Passphrase cancelled")
selector = EntrySelector(text) selector = EntrySelector(text)
return await ctx.wait(selector) return await ctx.wait(selector)
@ -31,28 +32,30 @@ async def request_passphrase_entry(ctx):
@ui.layout @ui.layout
async def request_passphrase_ack(ctx, on_device): async def request_passphrase_ack(ctx, on_device):
if not on_device: if not on_device:
text = Text('Passphrase entry', ui.ICON_CONFIG) text = Text("Passphrase entry", ui.ICON_CONFIG)
text.normal('Please, type passphrase', 'on connected host.') text.normal("Please, type passphrase", "on connected host.")
text.render() text.render()
req = PassphraseRequest(on_device=on_device) req = PassphraseRequest(on_device=on_device)
ack = await ctx.call(req, MessageType.PassphraseAck, MessageType.Cancel) ack = await ctx.call(req, MessageType.PassphraseAck, MessageType.Cancel)
if ack.MESSAGE_WIRE_TYPE == MessageType.Cancel: if ack.MESSAGE_WIRE_TYPE == MessageType.Cancel:
raise wire.ActionCancelled('Passphrase cancelled') raise wire.ActionCancelled("Passphrase cancelled")
if on_device: if on_device:
if ack.passphrase is not None: if ack.passphrase is not None:
raise wire.ProcessError('Passphrase provided when it should not be') raise wire.ProcessError("Passphrase provided when it should not be")
keyboard = PassphraseKeyboard('Enter passphrase') keyboard = PassphraseKeyboard("Enter passphrase")
passphrase = await ctx.wait(keyboard) passphrase = await ctx.wait(keyboard)
if passphrase == CANCELLED: if passphrase == CANCELLED:
raise wire.ActionCancelled('Passphrase cancelled') raise wire.ActionCancelled("Passphrase cancelled")
else: else:
if ack.passphrase is None: if ack.passphrase is None:
raise wire.ProcessError('Passphrase not provided') raise wire.ProcessError("Passphrase not provided")
passphrase = ack.passphrase passphrase = ack.passphrase
req = PassphraseStateRequest(state=get_state(prev_state=ack.state, passphrase=passphrase)) req = PassphraseStateRequest(
state=get_state(prev_state=ack.state, passphrase=passphrase)
)
ack = await ctx.call(req, MessageType.PassphraseStateAck, MessageType.Cancel) ack = await ctx.call(req, MessageType.PassphraseStateAck, MessageType.Cancel)
return passphrase return passphrase
@ -71,4 +74,4 @@ async def protect_by_passphrase(ctx):
if storage.has_passphrase(): if storage.has_passphrase():
return await request_passphrase(ctx) return await request_passphrase(ctx)
else: else:
return '' return ""

@ -11,32 +11,31 @@ class PinCancelled(Exception):
@ui.layout @ui.layout
async def request_pin(label=None, cancellable: bool=True) -> str: async def request_pin(label=None, cancellable: bool = True) -> str:
def onchange(): def onchange():
c = dialog.cancel c = dialog.cancel
if matrix.pin: if matrix.pin:
back = res.load(ui.ICON_BACK) back = res.load(ui.ICON_BACK)
if c.content is not back: if c.content is not back:
c.normal_style = ui.BTN_CLEAR['normal'] c.normal_style = ui.BTN_CLEAR["normal"]
c.content = back c.content = back
c.enable() c.enable()
c.taint() c.taint()
else: else:
lock = res.load(ui.ICON_LOCK) lock = res.load(ui.ICON_LOCK)
if not cancellable and c.content: if not cancellable and c.content:
c.content = '' c.content = ""
c.disable() c.disable()
c.taint() c.taint()
elif c.content is not lock: elif c.content is not lock:
c.normal_style = ui.BTN_CANCEL['normal'] c.normal_style = ui.BTN_CANCEL["normal"]
c.content = lock c.content = lock
c.enable() c.enable()
c.taint() c.taint()
c.render() c.render()
if label is None: if label is None:
label = 'Enter your PIN' label = "Enter your PIN"
matrix = PinMatrix(label) matrix = PinMatrix(label)
matrix.onchange = onchange matrix.onchange = onchange
dialog = ConfirmDialog(matrix) dialog = ConfirmDialog(matrix)
@ -54,7 +53,7 @@ async def request_pin(label=None, cancellable: bool=True) -> str:
if result == CONFIRMED: if result == CONFIRMED:
return matrix.pin return matrix.pin
elif matrix.pin: # reset elif matrix.pin: # reset
matrix.change('') matrix.change("")
continue continue
else: # cancel else: # cancel
raise PinCancelled() raise PinCancelled()

@ -4,10 +4,12 @@ from trezor.crypto import bip32, bip39
from apps.common import cache, storage from apps.common import cache, storage
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
_DEFAULT_CURVE = 'secp256k1' _DEFAULT_CURVE = "secp256k1"
async def derive_node(ctx: wire.Context, path: list, curve_name: str = _DEFAULT_CURVE) -> bip32.HDNode: async def derive_node(
ctx: wire.Context, path: list, curve_name: str = _DEFAULT_CURVE
) -> bip32.HDNode:
seed = await _get_cached_seed(ctx) seed = await _get_cached_seed(ctx)
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)
node.derive_path(path) node.derive_path(path)
@ -16,7 +18,7 @@ async def derive_node(ctx: wire.Context, path: list, curve_name: str = _DEFAULT_
async def _get_cached_seed(ctx: wire.Context) -> bytes: async def _get_cached_seed(ctx: wire.Context) -> bytes:
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.ProcessError('Device is not initialized') raise wire.ProcessError("Device is not initialized")
if cache.get_seed() is None: if cache.get_seed() is None:
passphrase = await _get_cached_passphrase(ctx) passphrase = await _get_cached_passphrase(ctx)
seed = bip39.seed(storage.get_mnemonic(), passphrase) seed = bip39.seed(storage.get_mnemonic(), passphrase)
@ -31,11 +33,13 @@ async def _get_cached_passphrase(ctx: wire.Context) -> str:
return cache.get_passphrase() return cache.get_passphrase()
def derive_node_without_passphrase(path: list, curve_name: str = _DEFAULT_CURVE) -> bip32.HDNode: def derive_node_without_passphrase(
path: list, curve_name: str = _DEFAULT_CURVE
) -> bip32.HDNode:
if not storage.is_initialized(): if not storage.is_initialized():
raise Exception('Device is not initialized') raise Exception("Device is not initialized")
seed = bip39.seed(storage.get_mnemonic(), '') seed = bip39.seed(storage.get_mnemonic(), "")
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)
node.derive_path(path) node.derive_path(path)
return node return node

@ -21,7 +21,6 @@ def message_digest(coin, message):
def split_message(message): def split_message(message):
def measure(s): def measure(s):
return ui.display.text_width(s, ui.NORMAL) return ui.display.text_width(s, ui.NORMAL)

@ -8,7 +8,7 @@ from apps.common import cache
HOMESCREEN_MAXSIZE = 16384 HOMESCREEN_MAXSIZE = 16384
_STORAGE_VERSION = b'\x01' _STORAGE_VERSION = b"\x01"
# fmt: off # fmt: off
_APP = const(0x01) # app namespace _APP = const(0x01) # app namespace
@ -64,9 +64,9 @@ def load_mnemonic(mnemonic: str, needs_backup: bool) -> None:
config.set(_APP, _MNEMONIC, mnemonic.encode()) config.set(_APP, _MNEMONIC, mnemonic.encode())
config.set(_APP, _VERSION, _STORAGE_VERSION) config.set(_APP, _VERSION, _STORAGE_VERSION)
if needs_backup: if needs_backup:
config.set(_APP, _NEEDS_BACKUP, b'\x01') config.set(_APP, _NEEDS_BACKUP, b"\x01")
else: else:
config.set(_APP, _NEEDS_BACKUP, b'') config.set(_APP, _NEEDS_BACKUP, b"")
def needs_backup() -> bool: def needs_backup() -> bool:
@ -74,7 +74,7 @@ def needs_backup() -> bool:
def set_backed_up() -> None: def set_backed_up() -> None:
config.set(_APP, _NEEDS_BACKUP, b'') config.set(_APP, _NEEDS_BACKUP, b"")
def unfinished_backup() -> bool: def unfinished_backup() -> bool:
@ -83,34 +83,39 @@ def unfinished_backup() -> bool:
def set_unfinished_backup(state: bool) -> None: def set_unfinished_backup(state: bool) -> None:
if state: if state:
config.set(_APP, _UNFINISHED_BACKUP, b'\x01') config.set(_APP, _UNFINISHED_BACKUP, b"\x01")
else: else:
config.set(_APP, _UNFINISHED_BACKUP, b'') config.set(_APP, _UNFINISHED_BACKUP, b"")
def get_passphrase_source() -> int: def get_passphrase_source() -> int:
b = config.get(_APP, _PASSPHRASE_SOURCE) b = config.get(_APP, _PASSPHRASE_SOURCE)
if b == b'\x01': if b == b"\x01":
return 1 return 1
elif b == b'\x02': elif b == b"\x02":
return 2 return 2
else: else:
return 0 return 0
def load_settings(label: str=None, use_passphrase: bool=None, homescreen: bytes=None, passphrase_source: int=None) -> None: def load_settings(
label: str = None,
use_passphrase: bool = None,
homescreen: bytes = None,
passphrase_source: int = None,
) -> None:
if label is not None: if label is not None:
config.set(_APP, _LABEL, label.encode(), True) # public config.set(_APP, _LABEL, label.encode(), True) # public
if use_passphrase is True: if use_passphrase is True:
config.set(_APP, _USE_PASSPHRASE, b'\x01') config.set(_APP, _USE_PASSPHRASE, b"\x01")
if use_passphrase is False: if use_passphrase is False:
config.set(_APP, _USE_PASSPHRASE, b'') config.set(_APP, _USE_PASSPHRASE, b"")
if homescreen is not None: if homescreen is not None:
if homescreen[:8] == b'TOIf\x90\x00\x90\x00': if homescreen[:8] == b"TOIf\x90\x00\x90\x00":
if len(homescreen) <= HOMESCREEN_MAXSIZE: if len(homescreen) <= HOMESCREEN_MAXSIZE:
config.set(_APP, _HOMESCREEN, homescreen, True) # public config.set(_APP, _HOMESCREEN, homescreen, True) # public
else: else:
config.set(_APP, _HOMESCREEN, b'', True) # public config.set(_APP, _HOMESCREEN, b"", True) # public
if passphrase_source is not None: if passphrase_source is not None:
if passphrase_source in [0, 1, 2]: if passphrase_source in [0, 1, 2]:
config.set(_APP, _PASSPHRASE_SOURCE, bytes([passphrase_source])) config.set(_APP, _PASSPHRASE_SOURCE, bytes([passphrase_source]))
@ -121,7 +126,7 @@ def get_flags() -> int:
if b is None: if b is None:
return 0 return 0
else: else:
return int.from_bytes(b, 'big') return int.from_bytes(b, "big")
def set_flags(flags: int) -> None: def set_flags(flags: int) -> None:
@ -129,10 +134,10 @@ def set_flags(flags: int) -> None:
if b is None: if b is None:
b = 0 b = 0
else: else:
b = int.from_bytes(b, 'big') b = int.from_bytes(b, "big")
flags = (flags | b) & 0xFFFFFFFF flags = (flags | b) & 0xFFFFFFFF
if flags != b: if flags != b:
config.set(_APP, _FLAGS, flags.to_bytes(4, 'big')) config.set(_APP, _FLAGS, flags.to_bytes(4, "big"))
def get_autolock_delay_ms() -> int: def get_autolock_delay_ms() -> int:
@ -140,13 +145,13 @@ def get_autolock_delay_ms() -> int:
if b is None: if b is None:
return 10 * 60 * 1000 return 10 * 60 * 1000
else: else:
return int.from_bytes(b, 'big') return int.from_bytes(b, "big")
def set_autolock_delay_ms(delay_ms: int) -> None: def set_autolock_delay_ms(delay_ms: int) -> None:
if delay_ms < 60 * 1000: if delay_ms < 60 * 1000:
delay_ms = 60 * 1000 delay_ms = 60 * 1000
config.set(_APP, _AUTOLOCK_DELAY_MS, delay_ms.to_bytes(4, 'big')) config.set(_APP, _AUTOLOCK_DELAY_MS, delay_ms.to_bytes(4, "big"))
def next_u2f_counter() -> int: def next_u2f_counter() -> int:
@ -154,13 +159,13 @@ def next_u2f_counter() -> int:
if b is None: if b is None:
b = 0 b = 0
else: else:
b = int.from_bytes(b, 'big') + 1 b = int.from_bytes(b, "big") + 1
set_u2f_counter(b) set_u2f_counter(b)
return b return b
def set_u2f_counter(cntr: int): def set_u2f_counter(cntr: int):
config.set(_APP, _U2F_COUNTER, cntr.to_bytes(4, 'big')) config.set(_APP, _U2F_COUNTER, cntr.to_bytes(4, "big"))
def wipe(): def wipe():

@ -1,6 +1,7 @@
if not __debug__: if not __debug__:
from trezor.utils import halt from trezor.utils import halt
halt('debug mode inactive')
halt("debug mode inactive")
if __debug__: if __debug__:
from trezor import loop from trezor import loop
@ -33,12 +34,16 @@ if __debug__:
m.reset_word_pos = reset_word_index m.reset_word_pos = reset_word_index
m.reset_entropy = reset_internal_entropy m.reset_entropy = reset_internal_entropy
if reset_current_words: if reset_current_words:
m.reset_word = ' '.join(reset_current_words) m.reset_word = " ".join(reset_current_words)
return m return m
def boot(): def boot():
# wipe storage when debug build is used # wipe storage when debug build is used
storage.wipe() storage.wipe()
register(MessageType.DebugLinkDecision, protobuf_workflow, dispatch_DebugLinkDecision) register(
register(MessageType.DebugLinkGetState, protobuf_workflow, dispatch_DebugLinkGetState) MessageType.DebugLinkDecision, protobuf_workflow, dispatch_DebugLinkDecision
)
register(
MessageType.DebugLinkGetState, protobuf_workflow, dispatch_DebugLinkGetState
)

@ -9,21 +9,25 @@ from trezor.wire import protobuf_workflow, register
def dispatch_EthereumGetAddress(*args, **kwargs): def dispatch_EthereumGetAddress(*args, **kwargs):
from .get_address import ethereum_get_address from .get_address import ethereum_get_address
return ethereum_get_address(*args, **kwargs) return ethereum_get_address(*args, **kwargs)
def dispatch_EthereumSignTx(*args, **kwargs): def dispatch_EthereumSignTx(*args, **kwargs):
from .sign_tx import ethereum_sign_tx from .sign_tx import ethereum_sign_tx
return ethereum_sign_tx(*args, **kwargs) return ethereum_sign_tx(*args, **kwargs)
def dispatch_EthereumSignMessage(*args, **kwargs): def dispatch_EthereumSignMessage(*args, **kwargs):
from .sign_message import ethereum_sign_message from .sign_message import ethereum_sign_message
return ethereum_sign_message(*args, **kwargs) return ethereum_sign_message(*args, **kwargs)
def dispatch_EthereumVerifyMessage(*args, **kwargs): def dispatch_EthereumVerifyMessage(*args, **kwargs):
from .verify_message import ethereum_verify_message from .verify_message import ethereum_verify_message
return ethereum_verify_message(*args, **kwargs) return ethereum_verify_message(*args, **kwargs)

@ -37,18 +37,18 @@ def _ethereum_address_hex(address, network=None):
hx = hexlify(address).decode() hx = hexlify(address).decode()
prefix = str(network.chain_id) + '0x' if rskip60 else '' prefix = str(network.chain_id) + "0x" if rskip60 else ""
hs = sha3_256(prefix + hx).digest(True) hs = sha3_256(prefix + hx).digest(True)
h = '' h = ""
for i in range(20): for i in range(20):
l = hx[i * 2] l = hx[i * 2]
if hs[i] & 0x80 and l >= 'a' and l <= 'f': if hs[i] & 0x80 and l >= "a" and l <= "f":
l = l.upper() l = l.upper()
h += l h += l
l = hx[i * 2 + 1] l = hx[i * 2 + 1]
if hs[i] & 0x08 and l >= 'a' and l <= 'f': if hs[i] & 0x08 and l >= "a" and l <= "f":
l = l.upper() l = l.upper()
h += l h += l
return '0x' + h return "0x" + h

@ -14,21 +14,23 @@ async def require_confirm_tx(ctx, to, value, chain_id, token=None, tx_type=None)
if to: if to:
to_str = _ethereum_address_hex(to, networks.by_chain_id(chain_id)) to_str = _ethereum_address_hex(to, networks.by_chain_id(chain_id))
else: else:
to_str = 'new contract?' to_str = "new contract?"
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_ethereum_amount(value, token, chain_id, tx_type)) text.bold(format_ethereum_amount(value, token, chain_id, tx_type))
text.normal('to') text.normal("to")
text.mono(*split_address(to_str)) text.mono(*split_address(to_str))
# we use SignTx, not ConfirmOutput, for compatibility with T1 # we use SignTx, not ConfirmOutput, for compatibility with T1
await require_confirm(ctx, text, ButtonRequestType.SignTx) await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_fee(ctx, spending, gas_price, gas_limit, chain_id, token=None, tx_type=None): async def require_confirm_fee(
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN) ctx, spending, gas_price, gas_limit, chain_id, token=None, tx_type=None
):
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_ethereum_amount(spending, token, chain_id, tx_type)) text.bold(format_ethereum_amount(spending, token, chain_id, tx_type))
text.normal('Gas price:') text.normal("Gas price:")
text.bold(format_ethereum_amount(gas_price, None, chain_id, tx_type)) text.bold(format_ethereum_amount(gas_price, None, chain_id, tx_type))
text.normal('Maximum fee:') text.normal("Maximum fee:")
text.bold(format_ethereum_amount(gas_price * gas_limit, None, chain_id, tx_type)) text.bold(format_ethereum_amount(gas_price * gas_limit, None, chain_id, tx_type))
await require_hold_to_confirm(ctx, text, ButtonRequestType.SignTx) await require_hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
@ -40,9 +42,9 @@ def split_data(data):
async def require_confirm_data(ctx, data, data_total): async def require_confirm_data(ctx, data, data_total):
data_str = hexlify(data[:36]).decode() data_str = hexlify(data[:36]).decode()
if data_total > 36: if data_total > 36:
data_str = data_str[:-2] + '..' data_str = data_str[:-2] + ".."
text = Text('Confirm data', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm data", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold('Size: %d bytes' % data_total) text.bold("Size: %d bytes" % data_total)
text.mono(*split_data(data_str)) text.mono(*split_data(data_str))
# we use SignTx, not ConfirmOutput, for compatibility with T1 # we use SignTx, not ConfirmOutput, for compatibility with T1
await require_confirm(ctx, text, ButtonRequestType.SignTx) await require_confirm(ctx, text, ButtonRequestType.SignTx)
@ -55,7 +57,7 @@ def split_address(address):
def format_ethereum_amount(value: int, token, chain_id: int, tx_type=None): def format_ethereum_amount(value: int, token, chain_id: int, tx_type=None):
if token: if token:
if token is tokens.UNKNOWN_TOKEN: if token is tokens.UNKNOWN_TOKEN:
return 'Unknown token value' return "Unknown token value"
suffix = token[2] suffix = token[2]
decimals = token[3] decimals = token[3]
else: else:
@ -63,7 +65,7 @@ def format_ethereum_amount(value: int, token, chain_id: int, tx_type=None):
decimals = 18 decimals = 18
if value <= 1e9: if value <= 1e9:
suffix = 'Wei ' + suffix suffix = "Wei " + suffix
decimals = 0 decimals = 0
return '%s %s' % (format_amount(value, decimals), suffix) return "%s %s" % (format_amount(value, decimals), suffix)

@ -1,9 +1,9 @@
def shortcut_by_chain_id(chain_id, tx_type=None): def shortcut_by_chain_id(chain_id, tx_type=None):
if tx_type in [1, 6] and chain_id in [1, 3]: if tx_type in [1, 6] and chain_id in [1, 3]:
return 'WAN' return "WAN"
else: else:
n = by_chain_id(chain_id) n = by_chain_id(chain_id)
return n.shortcut if n is not None else 'UNKN' return n.shortcut if n is not None else "UNKN"
def by_chain_id(chain_id): def by_chain_id(chain_id):
@ -21,14 +21,8 @@ def by_slip44(slip44):
class NetworkInfo: class NetworkInfo:
def __init__( def __init__(
self, self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool
chain_id: int,
slip44: int,
shortcut: str,
name: str,
rskip60: bool
): ):
self.chain_id = chain_id self.chain_id = chain_id
self.slip44 = slip44 self.slip44 = slip44

@ -12,7 +12,7 @@ from apps.common.signverify import split_message
def message_digest(message): def message_digest(message):
h = HashWriter(sha3_256) h = HashWriter(sha3_256)
signed_message_header = '\x19Ethereum Signed Message:\n' signed_message_header = "\x19Ethereum Signed Message:\n"
h.extend(signed_message_header) h.extend(signed_message_header)
h.extend(str(len(message))) h.extend(str(len(message)))
h.extend(message) h.extend(message)
@ -37,6 +37,6 @@ async def ethereum_sign_message(ctx, msg):
async def require_confirm_sign_message(ctx, message): async def require_confirm_sign_message(ctx, message):
message = split_message(message) message = split_message(message)
text = Text('Sign ETH message') text = Text("Sign ETH message")
text.normal(*message) text.normal(*message)
await require_confirm(ctx, text) await require_confirm(ctx, text)

@ -26,21 +26,32 @@ async def ethereum_sign_tx(ctx, msg):
# detect ERC - 20 token # detect ERC - 20 token
token = None token = None
recipient = msg.to recipient = msg.to
value = int.from_bytes(msg.value, 'big') value = int.from_bytes(msg.value, "big")
if len(msg.to) == 20 and \ if (
len(msg.value) == 0 and \ len(msg.to) == 20
data_total == 68 and \ and len(msg.value) == 0
len(msg.data_initial_chunk) == 68 and \ and data_total == 68
msg.data_initial_chunk[:16] == b'\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00': and len(msg.data_initial_chunk) == 68
and msg.data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
):
token = tokens.token_by_chain_address(msg.chain_id, msg.to) token = tokens.token_by_chain_address(msg.chain_id, msg.to)
recipient = msg.data_initial_chunk[16:36] recipient = msg.data_initial_chunk[16:36]
value = int.from_bytes(msg.data_initial_chunk[36:68], 'big') value = int.from_bytes(msg.data_initial_chunk[36:68], "big")
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token, msg.tx_type) await require_confirm_tx(ctx, recipient, value, msg.chain_id, token, msg.tx_type)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
await require_confirm_fee(ctx, value, int.from_bytes(msg.gas_price, 'big'), int.from_bytes(msg.gas_limit, 'big'), msg.chain_id, token, msg.tx_type) await require_confirm_fee(
ctx,
value,
int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"),
msg.chain_id,
token,
msg.tx_type,
)
data = bytearray() data = bytearray()
data += msg.data_initial_chunk data += msg.data_initial_chunk
@ -97,6 +108,7 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
async def send_request_chunk(ctx, data_left: int): async def send_request_chunk(ctx, data_left: int):
from trezor.messages.MessageType import EthereumTxAck from trezor.messages.MessageType import EthereumTxAck
# TODO: layoutProgress ? # TODO: layoutProgress ?
req = EthereumTxRequest() req = EthereumTxRequest()
if data_left <= 1024: if data_left <= 1024:
@ -129,24 +141,24 @@ async def send_signature(ctx, msg: EthereumSignTx, digest):
def check(msg: EthereumSignTx): def check(msg: EthereumSignTx):
if msg.tx_type not in [1, 6, None]: if msg.tx_type not in [1, 6, None]:
raise wire.DataError('tx_type out of bounds') raise wire.DataError("tx_type out of bounds")
if msg.chain_id < 0 or msg.chain_id > MAX_CHAIN_ID: if msg.chain_id < 0 or msg.chain_id > MAX_CHAIN_ID:
raise wire.DataError('chain_id out of bounds') raise wire.DataError("chain_id out of bounds")
if msg.data_length > 0: if msg.data_length > 0:
if not msg.data_initial_chunk: if not msg.data_initial_chunk:
raise wire.DataError('Data length provided, but no initial chunk') raise wire.DataError("Data length provided, but no initial chunk")
# Our encoding only supports transactions up to 2^24 bytes. To # Our encoding only supports transactions up to 2^24 bytes. To
# prevent exceeding the limit we use a stricter limit on data length. # prevent exceeding the limit we use a stricter limit on data length.
if msg.data_length > 16000000: if msg.data_length > 16000000:
raise wire.DataError('Data length exceeds limit') raise wire.DataError("Data length exceeds limit")
if len(msg.data_initial_chunk) > msg.data_length: if len(msg.data_initial_chunk) > msg.data_length:
raise wire.DataError('Invalid size of initial chunk') raise wire.DataError("Invalid size of initial chunk")
# safety checks # safety checks
if not check_gas(msg) or not check_to(msg): if not check_gas(msg) or not check_to(msg):
raise wire.DataError('Safety check failed') raise wire.DataError("Safety check failed")
def check_gas(msg: EthereumSignTx) -> bool: def check_gas(msg: EthereumSignTx) -> bool:
@ -159,7 +171,7 @@ def check_gas(msg: EthereumSignTx) -> bool:
def check_to(msg: EthereumTxRequest) -> bool: def check_to(msg: EthereumTxRequest) -> bool:
if msg.to == b'': if msg.to == b"":
if msg.data_length == 0: if msg.data_length == 0:
# sending transaction to address 0 (contract creation) without a data field # sending transaction to address 0 (contract creation) without a data field
return False return False
@ -171,15 +183,15 @@ def check_to(msg: EthereumTxRequest) -> bool:
def sanitize(msg): def sanitize(msg):
if msg.value is None: if msg.value is None:
msg.value = b'' msg.value = b""
if msg.data_initial_chunk is None: if msg.data_initial_chunk is None:
msg.data_initial_chunk = b'' msg.data_initial_chunk = b""
if msg.data_length is None: if msg.data_length is None:
msg.data_length = 0 msg.data_length = 0
if msg.to is None: if msg.to is None:
msg.to = b'' msg.to = b""
if msg.nonce is None: if msg.nonce is None:
msg.nonce = b'' msg.nonce = b""
if msg.chain_id is None: if msg.chain_id is None:
msg.chain_id = 0 msg.chain_id = 0
return msg return msg

File diff suppressed because it is too large Load Diff

@ -18,25 +18,25 @@ async def ethereum_verify_message(ctx, msg):
pubkey = secp256k1.verify_recover(sig, digest) pubkey = secp256k1.verify_recover(sig, digest)
if not pubkey: if not pubkey:
raise ValueError('Invalid signature') raise ValueError("Invalid signature")
pkh = sha3_256(pubkey[1:]).digest(True)[-20:] pkh = sha3_256(pubkey[1:]).digest(True)[-20:]
if msg.address != pkh: if msg.address != pkh:
raise ValueError('Invalid signature') raise ValueError("Invalid signature")
address = '0x' + hexlify(msg.address).decode() address = "0x" + hexlify(msg.address).decode()
await require_confirm_verify_message(ctx, address, msg.message) await require_confirm_verify_message(ctx, address, msg.message)
return Success(message='Message verified') return Success(message="Message verified")
async def require_confirm_verify_message(ctx, address, message): async def require_confirm_verify_message(ctx, address, message):
text = Text('Confirm address') text = Text("Confirm address")
text.mono(*split_address(address)) text.mono(*split_address(address))
await require_confirm(ctx, text) await require_confirm(ctx, text)
text = Text('Verify message') text = Text("Verify message")
text.mono(*split_message(message)) text.mono(*split_message(message))
await require_confirm(ctx, text) await require_confirm(ctx, text)

@ -18,17 +18,17 @@ _TYPE_INIT = const(0x80) # initial frame identifier
_TYPE_CONT = const(0x00) # continuation frame identifier _TYPE_CONT = const(0x00) # continuation frame identifier
# types of cmd # types of cmd
_CMD_PING = const(0x81) # echo data through local processor only _CMD_PING = const(0x81) # echo data through local processor only
_CMD_MSG = const(0x83) # send U2F message frame _CMD_MSG = const(0x83) # send U2F message frame
_CMD_LOCK = const(0x84) # send lock channel command _CMD_LOCK = const(0x84) # send lock channel command
_CMD_INIT = const(0x86) # channel initialization _CMD_INIT = const(0x86) # channel initialization
_CMD_WINK = const(0x88) # send device identification wink _CMD_WINK = const(0x88) # send device identification wink
_CMD_ERROR = const(0xbf) # error response _CMD_ERROR = const(0xbf) # error response
# types for the msg cmd # types for the msg cmd
_MSG_REGISTER = const(0x01) # registration command _MSG_REGISTER = const(0x01) # registration command
_MSG_AUTHENTICATE = const(0x02) # authenticate/sign command _MSG_AUTHENTICATE = const(0x02) # authenticate/sign command
_MSG_VERSION = const(0x03) # read version string command _MSG_VERSION = const(0x03) # read version string command
# hid error codes # hid error codes
_ERR_NONE = const(0x00) # no error _ERR_NONE = const(0x00) # no error
@ -52,8 +52,8 @@ _SW_INS_NOT_SUPPORTED = const(0x6d00)
_SW_CLA_NOT_SUPPORTED = const(0x6e00) _SW_CLA_NOT_SUPPORTED = const(0x6e00)
# init response # init response
_CAPFLAG_WINK = const(0x01) # device supports _CMD_WINK _CAPFLAG_WINK = const(0x01) # device supports _CMD_WINK
_U2FHID_IF_VERSION = const(2) # interface version _U2FHID_IF_VERSION = const(2) # interface version
# register response # register response
_U2F_KEY_PATH = const(0x80553246) _U2F_KEY_PATH = const(0x80553246)
@ -63,18 +63,18 @@ _U2F_ATT_CERT = b"0\x82\x01\x180\x81\xc0\x02\t\x00\xb1\xd9\x8fBdr\xd3,0\n\x06\x0
_BOGUS_APPID = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" _BOGUS_APPID = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
# authentication control byte # authentication control byte
_AUTH_ENFORCE = const(0x03) # enforce user presence and sign _AUTH_ENFORCE = const(0x03) # enforce user presence and sign
_AUTH_CHECK_ONLY = const(0x07) # check only _AUTH_CHECK_ONLY = const(0x07) # check only
_AUTH_FLAG_TUP = const(0x01) # test of user presence set _AUTH_FLAG_TUP = const(0x01) # test of user presence set
# common raw message format (ISO7816-4:2005 mapping) # common raw message format (ISO7816-4:2005 mapping)
_APDU_CLA = const(0) # uint8_t cla; // Class - reserved _APDU_CLA = const(0) # uint8_t cla; // Class - reserved
_APDU_INS = const(1) # uint8_t ins; // U2F instruction _APDU_INS = const(1) # uint8_t ins; // U2F instruction
_APDU_P1 = const(2) # uint8_t p1; // U2F parameter 1 _APDU_P1 = const(2) # uint8_t p1; // U2F parameter 1
_APDU_P2 = const(3) # uint8_t p2; // U2F parameter 2 _APDU_P2 = const(3) # uint8_t p2; // U2F parameter 2
_APDU_LC1 = const(4) # uint8_t lc1; // Length field, set to zero _APDU_LC1 = const(4) # uint8_t lc1; // Length field, set to zero
_APDU_LC2 = const(5) # uint8_t lc2; // Length field, MSB _APDU_LC2 = const(5) # uint8_t lc2; // Length field, MSB
_APDU_LC3 = const(6) # uint8_t lc3; // Length field, LSB _APDU_LC3 = const(6) # uint8_t lc3; // Length field, LSB
_APDU_DATA = const(7) # uint8_t data[1]; // Data field _APDU_DATA = const(7) # uint8_t data[1]; // Data field
@ -85,10 +85,10 @@ def frame_init() -> dict:
# uint8_t bcntl; // Message byte count - low part # uint8_t bcntl; // Message byte count - low part
# uint8_t data[HID_RPT_SIZE - 7]; // Data payload # uint8_t data[HID_RPT_SIZE - 7]; // Data payload
return { return {
'cid': 0 | uctypes.UINT32, "cid": 0 | uctypes.UINT32,
'cmd': 4 | uctypes.UINT8, "cmd": 4 | uctypes.UINT8,
'bcnt': 5 | uctypes.UINT16, "bcnt": 5 | uctypes.UINT16,
'data': (7 | uctypes.ARRAY, (_HID_RPT_SIZE - 7) | uctypes.UINT8), "data": (7 | uctypes.ARRAY, (_HID_RPT_SIZE - 7) | uctypes.UINT8),
} }
@ -97,9 +97,9 @@ def frame_cont() -> dict:
# uint8_t seq; // Sequence number - b7 cleared # uint8_t seq; // Sequence number - b7 cleared
# uint8_t data[HID_RPT_SIZE - 5]; // Data payload # uint8_t data[HID_RPT_SIZE - 5]; // Data payload
return { return {
'cid': 0 | uctypes.UINT32, "cid": 0 | uctypes.UINT32,
'seq': 4 | uctypes.UINT8, "seq": 4 | uctypes.UINT8,
'data': (5 | uctypes.ARRAY, (_HID_RPT_SIZE - 5) | uctypes.UINT8), "data": (5 | uctypes.ARRAY, (_HID_RPT_SIZE - 5) | uctypes.UINT8),
} }
@ -112,13 +112,13 @@ def resp_cmd_init() -> dict:
# uint8_t versionBuild; // Build version number # uint8_t versionBuild; // Build version number
# uint8_t capFlags; // Capabilities flags # uint8_t capFlags; // Capabilities flags
return { return {
'nonce': (0 | uctypes.ARRAY, 8 | uctypes.UINT8), "nonce": (0 | uctypes.ARRAY, 8 | uctypes.UINT8),
'cid': 8 | uctypes.UINT32, "cid": 8 | uctypes.UINT32,
'versionInterface': 12 | uctypes.UINT8, "versionInterface": 12 | uctypes.UINT8,
'versionMajor': 13 | uctypes.UINT8, "versionMajor": 13 | uctypes.UINT8,
'versionMinor': 14 | uctypes.UINT8, "versionMinor": 14 | uctypes.UINT8,
'versionBuild': 15 | uctypes.UINT8, "versionBuild": 15 | uctypes.UINT8,
'capFlags': 16 | uctypes.UINT8, "capFlags": 16 | uctypes.UINT8,
} }
@ -134,13 +134,13 @@ def resp_cmd_register(khlen: int, certlen: int, siglen: int) -> dict:
# uint8_t sig[siglen]; // Registration signature # uint8_t sig[siglen]; // Registration signature
# uint16_t status; # uint16_t status;
return { return {
'registerId': 0 | uctypes.UINT8, "registerId": 0 | uctypes.UINT8,
'pubKey': (1 | uctypes.ARRAY, 65 | uctypes.UINT8), "pubKey": (1 | uctypes.ARRAY, 65 | uctypes.UINT8),
'keyHandleLen': 66 | uctypes.UINT8, "keyHandleLen": 66 | uctypes.UINT8,
'keyHandle': (67 | uctypes.ARRAY, khlen | uctypes.UINT8), "keyHandle": (67 | uctypes.ARRAY, khlen | uctypes.UINT8),
'cert': (cert_ofs | uctypes.ARRAY, certlen | uctypes.UINT8), "cert": (cert_ofs | uctypes.ARRAY, certlen | uctypes.UINT8),
'sig': (sig_ofs | uctypes.ARRAY, siglen | uctypes.UINT8), "sig": (sig_ofs | uctypes.ARRAY, siglen | uctypes.UINT8),
'status': status_ofs | uctypes.UINT16, "status": status_ofs | uctypes.UINT16,
} }
@ -154,10 +154,10 @@ def req_cmd_authenticate(khlen: int) -> dict:
# uint8_t keyHandleLen; // Length of key handle # uint8_t keyHandleLen; // Length of key handle
# uint8_t keyHandle[khlen]; // Key handle # uint8_t keyHandle[khlen]; // Key handle
return { return {
'chal': (0 | uctypes.ARRAY, 32 | uctypes.UINT8), "chal": (0 | uctypes.ARRAY, 32 | uctypes.UINT8),
'appId': (32 | uctypes.ARRAY, 32 | uctypes.UINT8), "appId": (32 | uctypes.ARRAY, 32 | uctypes.UINT8),
'keyHandleLen': 64 | uctypes.UINT8, "keyHandleLen": 64 | uctypes.UINT8,
'keyHandle': (65 | uctypes.ARRAY, khlen | uctypes.UINT8), "keyHandle": (65 | uctypes.ARRAY, khlen | uctypes.UINT8),
} }
@ -168,17 +168,17 @@ def resp_cmd_authenticate(siglen: int) -> dict:
# uint8_t sig[siglen]; // Signature # uint8_t sig[siglen]; // Signature
# uint16_t status; # uint16_t status;
return { return {
'flags': 0 | uctypes.UINT8, "flags": 0 | uctypes.UINT8,
'ctr': 1 | uctypes.UINT32, "ctr": 1 | uctypes.UINT32,
'sig': (5 | uctypes.ARRAY, siglen | uctypes.UINT8), "sig": (5 | uctypes.ARRAY, siglen | uctypes.UINT8),
'status': status_ofs | uctypes.UINT16, "status": status_ofs | uctypes.UINT16,
} }
def overlay_struct(buf, desc): def overlay_struct(buf, desc):
desc_size = uctypes.sizeof(desc, uctypes.BIG_ENDIAN) desc_size = uctypes.sizeof(desc, uctypes.BIG_ENDIAN)
if desc_size > len(buf): if desc_size > len(buf):
raise ValueError('desc is too big (%d > %d)' % (desc_size, len(buf))) raise ValueError("desc is too big (%d > %d)" % (desc_size, len(buf)))
return uctypes.struct(uctypes.addressof(buf), desc, uctypes.BIG_ENDIAN) return uctypes.struct(uctypes.addressof(buf), desc, uctypes.BIG_ENDIAN)
@ -189,8 +189,9 @@ def make_struct(desc):
class Msg: class Msg:
def __init__(
def __init__(self, cid: int, cla: int, ins: int, p1: int, p2: int, lc: int, data: bytes) -> None: self, cid: int, cla: int, ins: int, p1: int, p2: int, lc: int, data: bytes
) -> None:
self.cid = cid self.cid = cid
self.cla = cla self.cla = cla
self.ins = ins self.ins = ins
@ -201,7 +202,6 @@ class Msg:
class Cmd: class Cmd:
def __init__(self, cid: int, cmd: int, data: bytes) -> None: def __init__(self, cid: int, cmd: int, data: bytes) -> None:
self.cid = cid self.cid = cid
self.cmd = cmd self.cmd = cmd
@ -212,10 +212,12 @@ class Cmd:
ins = self.data[_APDU_INS] ins = self.data[_APDU_INS]
p1 = self.data[_APDU_P1] p1 = self.data[_APDU_P1]
p2 = self.data[_APDU_P2] p2 = self.data[_APDU_P2]
lc = (self.data[_APDU_LC1] << 16) + \ lc = (
(self.data[_APDU_LC2] << 8) + \ (self.data[_APDU_LC1] << 16)
(self.data[_APDU_LC3]) + (self.data[_APDU_LC2] << 8)
data = self.data[_APDU_DATA:_APDU_DATA + lc] + (self.data[_APDU_LC3])
)
data = self.data[_APDU_DATA : _APDU_DATA + lc]
return Msg(self.cid, cla, ins, p1, p2, lc, data) return Msg(self.cid, cla, ins, p1, p2, lc, data)
@ -235,7 +237,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
if ifrm.cmd & _TYPE_MASK == _TYPE_CONT: if ifrm.cmd & _TYPE_MASK == _TYPE_CONT:
# unexpected cont packet, abort current msg # unexpected cont packet, abort current msg
if __debug__: if __debug__:
log.warning(__name__, '_TYPE_CONT') log.warning(__name__, "_TYPE_CONT")
return None return None
if datalen < bcnt: if datalen < bcnt:
@ -253,13 +255,13 @@ async def read_cmd(iface: io.HID) -> Cmd:
if cfrm.seq == _CMD_INIT: if cfrm.seq == _CMD_INIT:
# _CMD_INIT frame, cancels current channel # _CMD_INIT frame, cancels current channel
ifrm = overlay_struct(buf, desc_init) ifrm = overlay_struct(buf, desc_init)
data = ifrm.data[:ifrm.bcnt] data = ifrm.data[: ifrm.bcnt]
break break
if cfrm.cid != ifrm.cid: if cfrm.cid != ifrm.cid:
# cont frame for a different channel, reply with BUSY and skip # cont frame for a different channel, reply with BUSY and skip
if __debug__: if __debug__:
log.warning(__name__, '_ERR_CHANNEL_BUSY') log.warning(__name__, "_ERR_CHANNEL_BUSY")
await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface) await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface)
continue continue
@ -267,7 +269,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
# cont frame for this channel, but incorrect seq number, abort # cont frame for this channel, but incorrect seq number, abort
# current msg # current msg
if __debug__: if __debug__:
log.warning(__name__, '_ERR_INVALID_SEQ') log.warning(__name__, "_ERR_INVALID_SEQ")
await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface) await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface)
return None return None
@ -330,7 +332,6 @@ _CONFIRM_TIMEOUT_MS = const(10 * 1000)
class ConfirmState: class ConfirmState:
def __init__(self) -> None: def __init__(self) -> None:
self.reset() self.reset()
@ -382,19 +383,20 @@ class ConfirmState:
from trezor.ui.text import Text from trezor.ui.text import Text
if bytes(self.app_id) == _BOGUS_APPID: if bytes(self.app_id) == _BOGUS_APPID:
text = Text('U2F mismatch', ui.ICON_WRONG, icon_color=ui.RED) text = Text("U2F mismatch", ui.ICON_WRONG, icon_color=ui.RED)
text.normal('Another U2F device', 'was used to register', 'in this application.') text.normal(
"Another U2F device", "was used to register", "in this application."
)
text.render() text.render()
await loop.sleep(3 * 1000 * 1000) await loop.sleep(3 * 1000 * 1000)
self.confirmed = True self.confirmed = True
else: else:
content = ConfirmContent(self.action, self.app_id) content = ConfirmContent(self.action, self.app_id)
dialog = ConfirmDialog(content, ) dialog = ConfirmDialog(content)
self.confirmed = await dialog == CONFIRMED self.confirmed = await dialog == CONFIRMED
class ConfirmContent(ui.Widget): class ConfirmContent(ui.Widget):
def __init__(self, action: int, app_id: bytes) -> None: def __init__(self, action: int, app_id: bytes) -> None:
self.action = action self.action = action
self.app_id = app_id self.app_id = app_id
@ -411,25 +413,30 @@ class ConfirmContent(ui.Widget):
if app_id == _BOGUS_APPID: if app_id == _BOGUS_APPID:
# TODO: display a warning dialog for bogus app ids # TODO: display a warning dialog for bogus app ids
name = 'Another U2F device' name = "Another U2F device"
icon = res.load('apps/fido_u2f/res/u2f_generic.toif') # TODO: warning icon icon = res.load("apps/fido_u2f/res/u2f_generic.toif") # TODO: warning icon
elif app_id in knownapps.knownapps: elif app_id in knownapps.knownapps:
name = knownapps.knownapps[app_id] name = knownapps.knownapps[app_id]
try: try:
icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_')) icon = res.load(
"apps/fido_u2f/res/u2f_%s.toif" % name.lower().replace(" ", "_")
)
except Exception: except Exception:
icon = res.load('apps/fido_u2f/res/u2f_generic.toif') icon = res.load("apps/fido_u2f/res/u2f_generic.toif")
else: else:
name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode()) name = "%s...%s" % (
icon = res.load('apps/fido_u2f/res/u2f_generic.toif') hexlify(app_id[:4]).decode(),
hexlify(app_id[-4:]).decode(),
)
icon = res.load("apps/fido_u2f/res/u2f_generic.toif")
self.app_name = name self.app_name = name
self.app_icon = icon self.app_icon = icon
def render(self) -> None: def render(self) -> None:
if self.action == _CONFIRM_REGISTER: if self.action == _CONFIRM_REGISTER:
header = 'U2F Register' header = "U2F Register"
else: else:
header = 'U2F Authenticate' header = "U2F Authenticate"
ui.header(header, ui.ICON_DEFAULT, ui.GREEN, ui.BG, ui.GREEN) ui.header(header, ui.ICON_DEFAULT, ui.GREEN, ui.BG, ui.GREEN)
ui.display.image((ui.WIDTH - 64) // 2, 64, self.app_icon) ui.display.image((ui.WIDTH - 64) // 2, 64, self.app_icon)
ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG) ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG)
@ -441,46 +448,46 @@ def dispatch_cmd(req: Cmd, state: ConfirmState) -> Cmd:
if m.cla != 0: if m.cla != 0:
if __debug__: if __debug__:
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED') log.warning(__name__, "_SW_CLA_NOT_SUPPORTED")
return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED) return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED)
if m.lc + _APDU_DATA > len(req.data): if m.lc + _APDU_DATA > len(req.data):
if __debug__: if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH') log.warning(__name__, "_SW_WRONG_LENGTH")
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
if m.ins == _MSG_REGISTER: if m.ins == _MSG_REGISTER:
if __debug__: if __debug__:
log.debug(__name__, '_MSG_REGISTER') log.debug(__name__, "_MSG_REGISTER")
return msg_register(m, state) return msg_register(m, state)
elif m.ins == _MSG_AUTHENTICATE: elif m.ins == _MSG_AUTHENTICATE:
if __debug__: if __debug__:
log.debug(__name__, '_MSG_AUTHENTICATE') log.debug(__name__, "_MSG_AUTHENTICATE")
return msg_authenticate(m, state) return msg_authenticate(m, state)
elif m.ins == _MSG_VERSION: elif m.ins == _MSG_VERSION:
if __debug__: if __debug__:
log.debug(__name__, '_MSG_VERSION') log.debug(__name__, "_MSG_VERSION")
return msg_version(m) return msg_version(m)
else: else:
if __debug__: if __debug__:
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins) log.warning(__name__, "_SW_INS_NOT_SUPPORTED: %d", m.ins)
return msg_error(req.cid, _SW_INS_NOT_SUPPORTED) return msg_error(req.cid, _SW_INS_NOT_SUPPORTED)
elif req.cmd == _CMD_INIT: elif req.cmd == _CMD_INIT:
if __debug__: if __debug__:
log.debug(__name__, '_CMD_INIT') log.debug(__name__, "_CMD_INIT")
return cmd_init(req) return cmd_init(req)
elif req.cmd == _CMD_PING: elif req.cmd == _CMD_PING:
if __debug__: if __debug__:
log.debug(__name__, '_CMD_PING') log.debug(__name__, "_CMD_PING")
return req return req
elif req.cmd == _CMD_WINK: elif req.cmd == _CMD_WINK:
if __debug__: if __debug__:
log.debug(__name__, '_CMD_WINK') log.debug(__name__, "_CMD_WINK")
return req return req
else: else:
if __debug__: if __debug__:
log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd) log.warning(__name__, "_ERR_INVALID_CMD: %d", req.cmd)
return cmd_error(req.cid, _ERR_INVALID_CMD) return cmd_error(req.cid, _ERR_INVALID_CMD)
@ -510,13 +517,13 @@ def msg_register(req: Msg, state: ConfirmState) -> Cmd:
if not storage.is_initialized(): if not storage.is_initialized():
if __debug__: if __debug__:
log.warning(__name__, 'not initialized') log.warning(__name__, "not initialized")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# check length of input data # check length of input data
if len(req.data) != 64: if len(req.data) != 64:
if __debug__: if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data') log.warning(__name__, "_SW_WRONG_LENGTH req.data")
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
# parse challenge and app_id # parse challenge and app_id
@ -532,12 +539,12 @@ def msg_register(req: Msg, state: ConfirmState) -> Cmd:
# wait for a button or continue # wait for a button or continue
if not state.confirmed: if not state.confirmed:
if __debug__: if __debug__:
log.info(__name__, 'waiting for button') log.info(__name__, "waiting for button")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# sign the registration challenge and return # sign the registration challenge and return
if __debug__: if __debug__:
log.info(__name__, 'signing register') log.info(__name__, "signing register")
buf = msg_register_sign(chal, app_id) buf = msg_register_sign(chal, app_id)
state.reset() state.reset()
@ -554,11 +561,11 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
nodepath = [_U2F_KEY_PATH] + keypath nodepath = [_U2F_KEY_PATH] + keypath
# prepare signing key from random path, compute decompressed public key # prepare signing key from random path, compute decompressed public key
node = seed.derive_node_without_passphrase(nodepath, 'nist256p1') node = seed.derive_node_without_passphrase(nodepath, "nist256p1")
pubkey = nist256p1.publickey(node.private_key(), False) pubkey = nist256p1.publickey(node.private_key(), False)
# first half of keyhandle is keypath # first half of keyhandle is keypath
keybuf = ustruct.pack('>8L', *keypath) keybuf = ustruct.pack(">8L", *keypath)
# second half of keyhandle is a hmac of app_id and keypath # second half of keyhandle is a hmac of app_id and keypath
keybase = hmac.Hmac(node.private_key(), app_id, hashlib.sha256) keybase = hmac.Hmac(node.private_key(), app_id, hashlib.sha256)
@ -567,12 +574,12 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
# hash the request data together with keyhandle and pubkey # hash the request data together with keyhandle and pubkey
dig = hashlib.sha256() dig = hashlib.sha256()
dig.update(b'\x00') # uint8_t reserved; dig.update(b"\x00") # uint8_t reserved;
dig.update(app_id) # uint8_t appId[32]; dig.update(app_id) # uint8_t appId[32];
dig.update(challenge) # uint8_t chal[32]; dig.update(challenge) # uint8_t chal[32];
dig.update(keybuf) # uint8_t keyHandle[64]; dig.update(keybuf) # uint8_t keyHandle[64];
dig.update(keybase) dig.update(keybase)
dig.update(pubkey) # uint8_t pubKey[65]; dig.update(pubkey) # uint8_t pubKey[65];
dig = dig.digest() dig = dig.digest()
# sign the digest and convert to der # sign the digest and convert to der
@ -580,8 +587,9 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
sig = der.encode_seq((sig[1:33], sig[33:])) sig = der.encode_seq((sig[1:33], sig[33:]))
# pack to a response # pack to a response
buf, resp = make_struct(resp_cmd_register( buf, resp = make_struct(
len(keybuf) + len(keybase), len(_U2F_ATT_CERT), len(sig))) resp_cmd_register(len(keybuf) + len(keybase), len(_U2F_ATT_CERT), len(sig))
)
resp.registerId = _U2F_REGISTER_ID resp.registerId = _U2F_REGISTER_ID
utils.memcpy(resp.pubKey, 0, pubkey, 0, len(pubkey)) utils.memcpy(resp.pubKey, 0, pubkey, 0, len(pubkey))
resp.keyHandleLen = len(keybuf) + len(keybase) resp.keyHandleLen = len(keybuf) + len(keybase)
@ -599,20 +607,20 @@ def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
if not storage.is_initialized(): if not storage.is_initialized():
if __debug__: if __debug__:
log.warning(__name__, 'not initialized') log.warning(__name__, "not initialized")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# we need at least keyHandleLen # we need at least keyHandleLen
if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN: if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN:
if __debug__: if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data') log.warning(__name__, "_SW_WRONG_LENGTH req.data")
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
# check keyHandleLen # check keyHandleLen
khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN] khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN]
if khlen != 64: if khlen != 64:
if __debug__: if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH khlen') log.warning(__name__, "_SW_WRONG_LENGTH khlen")
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
auth = overlay_struct(req.data, req_cmd_authenticate(khlen)) auth = overlay_struct(req.data, req_cmd_authenticate(khlen))
@ -626,13 +634,13 @@ def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
# if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already # if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already
if req.p1 == _AUTH_CHECK_ONLY: if req.p1 == _AUTH_CHECK_ONLY:
if __debug__: if __debug__:
log.info(__name__, '_AUTH_CHECK_ONLY') log.info(__name__, "_AUTH_CHECK_ONLY")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# from now on, only _AUTH_ENFORCE is supported # from now on, only _AUTH_ENFORCE is supported
if req.p1 != _AUTH_ENFORCE: if req.p1 != _AUTH_ENFORCE:
if __debug__: if __debug__:
log.info(__name__, '_AUTH_ENFORCE') log.info(__name__, "_AUTH_ENFORCE")
return msg_error(req.cid, _SW_WRONG_DATA) return msg_error(req.cid, _SW_WRONG_DATA)
# check equality with last request # check equality with last request
@ -644,12 +652,12 @@ def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
# wait for a button or continue # wait for a button or continue
if not state.confirmed: if not state.confirmed:
if __debug__: if __debug__:
log.info(__name__, 'waiting for button') log.info(__name__, "waiting for button")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# sign the authentication challenge and return # sign the authentication challenge and return
if __debug__: if __debug__:
log.info(__name__, 'signing authentication') log.info(__name__, "signing authentication")
buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key()) buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key())
state.reset() state.reset()
@ -662,18 +670,18 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# unpack the keypath from the first half of keyhandle # unpack the keypath from the first half of keyhandle
keybuf = keyhandle[:32] keybuf = keyhandle[:32]
keypath = ustruct.unpack('>8L', keybuf) keypath = ustruct.unpack(">8L", keybuf)
# check high bit for hardened keys # check high bit for hardened keys
for i in keypath: for i in keypath:
if not i & HARDENED: if not i & HARDENED:
if __debug__: if __debug__:
log.warning(__name__, 'invalid key path') log.warning(__name__, "invalid key path")
return None return None
# derive the signing key # derive the signing key
nodepath = [_U2F_KEY_PATH] + list(keypath) nodepath = [_U2F_KEY_PATH] + list(keypath)
node = seed.derive_node_without_passphrase(nodepath, 'nist256p1') node = seed.derive_node_without_passphrase(nodepath, "nist256p1")
# second half of keyhandle is a hmac of app_id and keypath # second half of keyhandle is a hmac of app_id and keypath
keybase = hmac.Hmac(node.private_key(), app_id, hashlib.sha256) keybase = hmac.Hmac(node.private_key(), app_id, hashlib.sha256)
@ -683,7 +691,7 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# verify the hmac # verify the hmac
if keybase != keyhandle[32:]: if keybase != keyhandle[32:]:
if __debug__: if __debug__:
log.warning(__name__, 'invalid key handle') log.warning(__name__, "invalid key handle")
return None return None
return node return node
@ -694,13 +702,13 @@ def msg_authenticate_sign(challenge: bytes, app_id: bytes, privkey: bytes) -> by
# get next counter # get next counter
ctr = storage.next_u2f_counter() ctr = storage.next_u2f_counter()
ctrbuf = ustruct.pack('>L', ctr) ctrbuf = ustruct.pack(">L", ctr)
# hash input data together with counter # hash input data together with counter
dig = hashlib.sha256() dig = hashlib.sha256()
dig.update(app_id) # uint8_t appId[32]; dig.update(app_id) # uint8_t appId[32];
dig.update(flags) # uint8_t flags; dig.update(flags) # uint8_t flags;
dig.update(ctrbuf) # uint8_t ctr[4]; dig.update(ctrbuf) # uint8_t ctr[4];
dig.update(challenge) # uint8_t chal[32]; dig.update(challenge) # uint8_t chal[32];
dig = dig.digest() dig = dig.digest()
@ -721,12 +729,12 @@ def msg_authenticate_sign(challenge: bytes, app_id: bytes, privkey: bytes) -> by
def msg_version(req: Msg) -> Cmd: def msg_version(req: Msg) -> Cmd:
if req.data: if req.data:
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
return Cmd(req.cid, _CMD_MSG, b'U2F_V2\x90\x00') # includes _SW_NO_ERROR return Cmd(req.cid, _CMD_MSG, b"U2F_V2\x90\x00") # includes _SW_NO_ERROR
def msg_error(cid: int, code: int) -> Cmd: def msg_error(cid: int, code: int) -> Cmd:
return Cmd(cid, _CMD_MSG, ustruct.pack('>H', code)) return Cmd(cid, _CMD_MSG, ustruct.pack(">H", code))
def cmd_error(cid: int, code: int) -> Cmd: def cmd_error(cid: int, code: int) -> Cmd:
return Cmd(cid, _CMD_ERROR, ustruct.pack('>B', code)) return Cmd(cid, _CMD_ERROR, ustruct.pack(">B", code))

@ -1,19 +1,25 @@
from trezor.crypto import hashlib from trezor.crypto import hashlib
knownapps = { knownapps = {
hashlib.sha256(b'https://account.gandi.net/api/u2f/trusted_facets.json').digest(): 'Gandi', hashlib.sha256(
hashlib.sha256(b'https://api-9dcf9b83.duosecurity.com').digest(): 'Duo', b"https://account.gandi.net/api/u2f/trusted_facets.json"
hashlib.sha256(b'https://bitbucket.org').digest(): 'Bitbucket', ).digest(): "Gandi",
hashlib.sha256(b'https://dashboard.stripe.com').digest(): 'Stripe', hashlib.sha256(b"https://api-9dcf9b83.duosecurity.com").digest(): "Duo",
hashlib.sha256(b'https://demo.yubico.com').digest(): 'Yubico U2F Demo', hashlib.sha256(b"https://bitbucket.org").digest(): "Bitbucket",
hashlib.sha256(b'https://github.com/u2f/trusted_facets').digest(): 'GitHub', hashlib.sha256(b"https://dashboard.stripe.com").digest(): "Stripe",
hashlib.sha256(b'https://gitlab.com').digest(): 'GitLab', hashlib.sha256(b"https://demo.yubico.com").digest(): "Yubico U2F Demo",
hashlib.sha256(b'https://keepersecurity.com').digest(): 'Keeper', hashlib.sha256(b"https://github.com/u2f/trusted_facets").digest(): "GitHub",
hashlib.sha256(b'https://slushpool.com/static/security/u2f.json').digest(): 'Slush Pool', hashlib.sha256(b"https://gitlab.com").digest(): "GitLab",
hashlib.sha256(b'https://u2f.bin.coffee').digest(): 'u2f.bin.coffee checker', hashlib.sha256(b"https://keepersecurity.com").digest(): "Keeper",
hashlib.sha256(b'https://vault.bitwarden.com/app-id.json').digest(): 'bitwarden', hashlib.sha256(
hashlib.sha256(b'https://www.bitfinex.com').digest(): 'Bitfinex', b"https://slushpool.com/static/security/u2f.json"
hashlib.sha256(b'https://www.dropbox.com/u2f-app-id.json').digest(): 'Dropbox', ).digest(): "Slush Pool",
hashlib.sha256(b'https://www.fastmail.com').digest(): 'FastMail', hashlib.sha256(b"https://u2f.bin.coffee").digest(): "u2f.bin.coffee checker",
hashlib.sha256(b'https://www.gstatic.com/securitykey/origins.json').digest(): 'Google', hashlib.sha256(b"https://vault.bitwarden.com/app-id.json").digest(): "bitwarden",
hashlib.sha256(b"https://www.bitfinex.com").digest(): "Bitfinex",
hashlib.sha256(b"https://www.dropbox.com/u2f-app-id.json").digest(): "Dropbox",
hashlib.sha256(b"https://www.fastmail.com").digest(): "FastMail",
hashlib.sha256(
b"https://www.gstatic.com/securitykey/origins.json"
).digest(): "Google",
} }

@ -9,15 +9,15 @@ from apps.common import cache, storage
def get_features(): def get_features():
f = Features() f = Features()
f.vendor = 'trezor.io' f.vendor = "trezor.io"
f.language = 'english' f.language = "english"
f.major_version = utils.symbol('VERSION_MAJOR') f.major_version = utils.symbol("VERSION_MAJOR")
f.minor_version = utils.symbol('VERSION_MINOR') f.minor_version = utils.symbol("VERSION_MINOR")
f.patch_version = utils.symbol('VERSION_PATCH') f.patch_version = utils.symbol("VERSION_PATCH")
f.revision = utils.symbol('GITREV') f.revision = utils.symbol("GITREV")
f.model = utils.model() f.model = utils.model()
if f.model == 'EMU': if f.model == "EMU":
f.model = 'T' # emulator currently emulates model T f.model = "T" # emulator currently emulates model T
f.device_id = storage.get_device_id() f.device_id = storage.get_device_id()
f.label = storage.get_label() f.label = storage.get_label()
f.initialized = storage.is_initialized() f.initialized = storage.is_initialized()
@ -42,12 +42,12 @@ async def handle_GetFeatures(ctx, msg):
async def handle_Cancel(ctx, msg): async def handle_Cancel(ctx, msg):
raise wire.ActionCancelled('Cancelled') raise wire.ActionCancelled("Cancelled")
async def handle_ClearSession(ctx, msg): async def handle_ClearSession(ctx, msg):
cache.clear() cache.clear()
return Success(message='Session cleared') return Success(message="Session cleared")
async def handle_Ping(ctx, msg): async def handle_Ping(ctx, msg):
@ -55,9 +55,11 @@ async def handle_Ping(ctx, msg):
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from trezor.messages.ButtonRequestType import ProtectCall from trezor.messages.ButtonRequestType import ProtectCall
from trezor.ui.text import Text from trezor.ui.text import Text
await require_confirm(ctx, Text('Confirm'), ProtectCall)
await require_confirm(ctx, Text("Confirm"), ProtectCall)
if msg.passphrase_protection: if msg.passphrase_protection:
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
await protect_by_passphrase(ctx) await protect_by_passphrase(ctx)
return Success(message=msg.message) return Success(message=msg.message)

@ -14,26 +14,32 @@ async def homescreen():
def display_homescreen(): def display_homescreen():
if not storage.is_initialized(): if not storage.is_initialized():
label = 'Go to trezor.io/start' label = "Go to trezor.io/start"
image = None image = None
else: else:
label = storage.get_label() or 'My TREZOR' label = storage.get_label() or "My TREZOR"
image = storage.get_homescreen() image = storage.get_homescreen()
if not image: if not image:
image = res.load('apps/homescreen/res/bg.toif') image = res.load("apps/homescreen/res/bg.toif")
if storage.is_initialized() and storage.unfinished_backup(): if storage.is_initialized() and storage.unfinished_backup():
ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED) ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED)
ui.display.text_center(ui.WIDTH // 2, 22, 'BACKUP FAILED!', ui.BOLD, ui.WHITE, ui.RED) ui.display.text_center(
ui.WIDTH // 2, 22, "BACKUP FAILED!", ui.BOLD, ui.WHITE, ui.RED
)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
elif storage.is_initialized() and storage.needs_backup(): elif storage.is_initialized() and storage.needs_backup():
ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW) ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW)
ui.display.text_center(ui.WIDTH // 2, 22, 'NEEDS BACKUP!', ui.BOLD, ui.BLACK, ui.YELLOW) ui.display.text_center(
ui.WIDTH // 2, 22, "NEEDS BACKUP!", ui.BOLD, ui.BLACK, ui.YELLOW
)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
elif storage.is_initialized() and not config.has_pin(): elif storage.is_initialized() and not config.has_pin():
ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW) ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW)
ui.display.text_center(ui.WIDTH // 2, 22, 'PIN NOT SET!', ui.BOLD, ui.BLACK, ui.YELLOW) ui.display.text_center(
ui.WIDTH // 2, 22, "PIN NOT SET!", ui.BOLD, ui.BLACK, ui.YELLOW
)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
else: else:
ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT, ui.BG) ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT, ui.BG)

@ -10,26 +10,31 @@ from trezor.wire import protobuf_workflow, register
def dispatch_LiskGetAddress(*args, **kwargs): def dispatch_LiskGetAddress(*args, **kwargs):
from .get_address import layout_lisk_get_address from .get_address import layout_lisk_get_address
return layout_lisk_get_address(*args, **kwargs) return layout_lisk_get_address(*args, **kwargs)
def dispatch_LiskGetPublicKey(*args, **kwargs): def dispatch_LiskGetPublicKey(*args, **kwargs):
from .get_public_key import lisk_get_public_key from .get_public_key import lisk_get_public_key
return lisk_get_public_key(*args, **kwargs) return lisk_get_public_key(*args, **kwargs)
def dispatch_LiskSignTx(*args, **kwargs): def dispatch_LiskSignTx(*args, **kwargs):
from .sign_tx import lisk_sign_tx from .sign_tx import lisk_sign_tx
return lisk_sign_tx(*args, **kwargs) return lisk_sign_tx(*args, **kwargs)
def dispatch_LiskSignMessage(*args, **kwargs): def dispatch_LiskSignMessage(*args, **kwargs):
from .sign_message import lisk_sign_message from .sign_message import lisk_sign_message
return lisk_sign_message(*args, **kwargs) return lisk_sign_message(*args, **kwargs)
def dispatch_LiskVerifyMessage(*args, **kwargs): def dispatch_LiskVerifyMessage(*args, **kwargs):
from .verify_message import lisk_verify_message from .verify_message import lisk_verify_message
return lisk_verify_message(*args, **kwargs) return lisk_verify_message(*args, **kwargs)

@ -1,18 +1,18 @@
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
LISK_CURVE = 'ed25519' LISK_CURVE = "ed25519"
def get_address_from_public_key(pubkey): def get_address_from_public_key(pubkey):
pubkeyhash = sha256(pubkey).digest() pubkeyhash = sha256(pubkey).digest()
address = int.from_bytes(pubkeyhash[:8], 'little') address = int.from_bytes(pubkeyhash[:8], "little")
return str(address) + 'L' return str(address) + "L"
def get_votes_count(votes): def get_votes_count(votes):
plus, minus = 0, 0 plus, minus = 0, 0
for vote in votes: for vote in votes:
if vote.startswith('+'): if vote.startswith("+"):
plus += 1 plus += 1
else: else:
minus += 1 minus += 1
@ -23,11 +23,11 @@ def get_vote_tx_text(votes):
plus, minus = get_votes_count(votes) plus, minus = get_votes_count(votes)
text = [] text = []
if plus > 0: if plus > 0:
text.append(_text_with_plural('Add', plus)) text.append(_text_with_plural("Add", plus))
if minus > 0: if minus > 0:
text.append(_text_with_plural('Remove', minus)) text.append(_text_with_plural("Remove", minus))
return text return text
def _text_with_plural(txt, value): def _text_with_plural(txt, value):
return '%s %s %s' % (txt, value, ('votes' if value != 1 else 'vote')) return "%s %s %s" % (txt, value, ("votes" if value != 1 else "vote"))

@ -10,23 +10,23 @@ from apps.wallet.get_public_key import _show_pubkey
async def require_confirm_tx(ctx, to, value): async def require_confirm_tx(ctx, to, value):
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_amount(value)) text.bold(format_amount(value))
text.normal('to') text.normal("to")
text.mono(*split_address(to)) text.mono(*split_address(to))
return await require_confirm(ctx, text, ButtonRequestType.SignTx) return await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_delegate_registration(ctx, delegate_name): async def require_confirm_delegate_registration(ctx, delegate_name):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Do you really want to') text.normal("Do you really want to")
text.normal('register a delegate?') text.normal("register a delegate?")
text.bold(*chunks(delegate_name, 20)) text.bold(*chunks(delegate_name, 20))
return await require_confirm(ctx, text, ButtonRequestType.SignTx) return await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_vote_tx(ctx, votes): async def require_confirm_vote_tx(ctx, votes):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(*get_vote_tx_text(votes)) text.normal(*get_vote_tx_text(votes))
return await require_confirm(ctx, text, ButtonRequestType.SignTx) return await require_confirm(ctx, text, ButtonRequestType.SignTx)
@ -36,23 +36,23 @@ async def require_confirm_public_key(ctx, public_key):
async def require_confirm_multisig(ctx, multisignature): async def require_confirm_multisig(ctx, multisignature):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Keys group length: %s' % len(multisignature.keys_group)) text.normal("Keys group length: %s" % len(multisignature.keys_group))
text.normal('Life time: %s' % multisignature.life_time) text.normal("Life time: %s" % multisignature.life_time)
text.normal('Min: %s' % multisignature.min) text.normal("Min: %s" % multisignature.min)
return await require_confirm(ctx, text, ButtonRequestType.SignTx) return await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_fee(ctx, value, fee): async def require_confirm_fee(ctx, value, fee):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_amount(value)) text.bold(format_amount(value))
text.normal('fee:') text.normal("fee:")
text.bold(format_amount(fee)) text.bold(format_amount(fee))
await require_hold_to_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_hold_to_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
def format_amount(value): def format_amount(value):
return '%s LSK' % (int(value) / 100000000) return "%s LSK" % (int(value) / 100000000)
def split_address(address): def split_address(address):

@ -14,7 +14,7 @@ from apps.wallet.sign_tx.signing import write_varint
def message_digest(message): def message_digest(message):
h = HashWriter(sha256) h = HashWriter(sha256)
signed_message_header = 'Lisk Signed Message:\n' signed_message_header = "Lisk Signed Message:\n"
write_varint(h, len(signed_message_header)) write_varint(h, len(signed_message_header))
h.extend(signed_message_header) h.extend(signed_message_header)
write_varint(h, len(message)) write_varint(h, len(message))
@ -39,6 +39,6 @@ async def lisk_sign_message(ctx, msg):
async def require_confirm_sign_message(ctx, message): async def require_confirm_sign_message(ctx, message):
text = Text('Sign Lisk message') text = Text("Sign Lisk message")
text.normal(*split_message(message)) text.normal(*split_message(message))
await require_confirm(ctx, text) await require_confirm(ctx, text)

@ -20,7 +20,7 @@ async def lisk_sign_tx(ctx, msg):
try: try:
await _require_confirm_by_type(ctx, transaction) await _require_confirm_by_type(ctx, transaction)
except AttributeError: except AttributeError:
raise wire.DataError('The transaction has invalid asset data field') raise wire.DataError("The transaction has invalid asset data field")
await layout.require_confirm_fee(ctx, transaction.amount, transaction.fee) await layout.require_confirm_fee(ctx, transaction.amount, transaction.fee)
@ -61,48 +61,59 @@ async def _require_confirm_by_type(ctx, transaction):
if transaction.type == LiskTransactionType.Transfer: if transaction.type == LiskTransactionType.Transfer:
return await layout.require_confirm_tx( return await layout.require_confirm_tx(
ctx, transaction.recipient_id, transaction.amount) ctx, transaction.recipient_id, transaction.amount
)
if transaction.type == LiskTransactionType.RegisterDelegate: if transaction.type == LiskTransactionType.RegisterDelegate:
return await layout.require_confirm_delegate_registration( return await layout.require_confirm_delegate_registration(
ctx, transaction.asset.delegate.username) ctx, transaction.asset.delegate.username
)
if transaction.type == LiskTransactionType.CastVotes: if transaction.type == LiskTransactionType.CastVotes:
return await layout.require_confirm_vote_tx( return await layout.require_confirm_vote_tx(ctx, transaction.asset.votes)
ctx, transaction.asset.votes)
if transaction.type == LiskTransactionType.RegisterSecondPassphrase: if transaction.type == LiskTransactionType.RegisterSecondPassphrase:
return await layout.require_confirm_public_key( return await layout.require_confirm_public_key(
ctx, transaction.asset.signature.public_key) ctx, transaction.asset.signature.public_key
)
if transaction.type == LiskTransactionType.RegisterMultisignatureAccount: if transaction.type == LiskTransactionType.RegisterMultisignatureAccount:
return await layout.require_confirm_multisig( return await layout.require_confirm_multisig(
ctx, transaction.asset.multisignature) ctx, transaction.asset.multisignature
)
raise wire.DataError('Invalid transaction type') raise wire.DataError("Invalid transaction type")
def _get_transaction_bytes(tx): def _get_transaction_bytes(tx):
# Required transaction parameters # Required transaction parameters
t_type = ustruct.pack('<b', tx.type) t_type = ustruct.pack("<b", tx.type)
t_timestamp = ustruct.pack('<i', tx.timestamp) t_timestamp = ustruct.pack("<i", tx.timestamp)
t_sender_public_key = tx.sender_public_key t_sender_public_key = tx.sender_public_key
t_requester_public_key = tx.requester_public_key or b'' t_requester_public_key = tx.requester_public_key or b""
if not tx.recipient_id: if not tx.recipient_id:
# Value can be empty string # Value can be empty string
t_recipient_id = ustruct.pack('>Q', 0) t_recipient_id = ustruct.pack(">Q", 0)
else: else:
# Lisk uses big-endian for recipient_id, string -> int -> bytes # Lisk uses big-endian for recipient_id, string -> int -> bytes
t_recipient_id = ustruct.pack('>Q', int(tx.recipient_id[:-1])) t_recipient_id = ustruct.pack(">Q", int(tx.recipient_id[:-1]))
t_amount = ustruct.pack('<Q', tx.amount) t_amount = ustruct.pack("<Q", tx.amount)
t_asset = _get_asset_data_bytes(tx) t_asset = _get_asset_data_bytes(tx)
t_signature = tx.signature or b'' t_signature = tx.signature or b""
return (t_type, t_timestamp, t_sender_public_key, t_requester_public_key, return (
t_recipient_id, t_amount, t_asset, t_signature) t_type,
t_timestamp,
t_sender_public_key,
t_requester_public_key,
t_recipient_id,
t_amount,
t_asset,
t_signature,
)
def _get_asset_data_bytes(msg): def _get_asset_data_bytes(msg):
@ -110,24 +121,24 @@ def _get_asset_data_bytes(msg):
if msg.type == LiskTransactionType.Transfer: if msg.type == LiskTransactionType.Transfer:
# Transfer transaction have optional data field # Transfer transaction have optional data field
if msg.asset.data is not None: if msg.asset.data is not None:
return bytes(msg.asset.data, 'utf8') return bytes(msg.asset.data, "utf8")
else: else:
return b'' return b""
if msg.type == LiskTransactionType.RegisterDelegate: if msg.type == LiskTransactionType.RegisterDelegate:
return bytes(msg.asset.delegate.username, 'utf8') return bytes(msg.asset.delegate.username, "utf8")
if msg.type == LiskTransactionType.CastVotes: if msg.type == LiskTransactionType.CastVotes:
return bytes(''.join(msg.asset.votes), 'utf8') return bytes("".join(msg.asset.votes), "utf8")
if msg.type == LiskTransactionType.RegisterSecondPassphrase: if msg.type == LiskTransactionType.RegisterSecondPassphrase:
return msg.asset.signature.public_key return msg.asset.signature.public_key
if msg.type == LiskTransactionType.RegisterMultisignatureAccount: if msg.type == LiskTransactionType.RegisterMultisignatureAccount:
data = b'' data = b""
data += ustruct.pack('<b', msg.asset.multisignature.min) data += ustruct.pack("<b", msg.asset.multisignature.min)
data += ustruct.pack('<b', msg.asset.multisignature.life_time) data += ustruct.pack("<b", msg.asset.multisignature.life_time)
data += bytes(''.join(msg.asset.multisignature.keys_group), 'utf8') data += bytes("".join(msg.asset.multisignature.keys_group), "utf8")
return data return data
raise wire.DataError('Invalid transaction type') raise wire.DataError("Invalid transaction type")

@ -12,9 +12,9 @@ async def lisk_verify_message(ctx, msg):
digest = message_digest(msg.message) digest = message_digest(msg.message)
verified = ed25519.verify(msg.public_key, msg.signature, digest) verified = ed25519.verify(msg.public_key, msg.signature, digest)
if not verified: if not verified:
raise wire.ProcessError('Invalid signature') raise wire.ProcessError("Invalid signature")
address = get_address_from_public_key(msg.public_key) address = get_address_from_public_key(msg.public_key)
await require_confirm_verify_message(ctx, address, msg.message) await require_confirm_verify_message(ctx, address, msg.message)
return Success(message='Message verified') return Success(message="Message verified")

@ -14,46 +14,55 @@ from trezor.wire import protobuf_workflow, register
def dispatch_LoadDevice(*args, **kwargs): def dispatch_LoadDevice(*args, **kwargs):
from .load_device import load_device from .load_device import load_device
return load_device(*args, **kwargs) return load_device(*args, **kwargs)
def dispatch_ResetDevice(*args, **kwargs): def dispatch_ResetDevice(*args, **kwargs):
from .reset_device import reset_device from .reset_device import reset_device
return reset_device(*args, **kwargs) return reset_device(*args, **kwargs)
def dispatch_BackupDevice(*args, **kwargs): def dispatch_BackupDevice(*args, **kwargs):
from .backup_device import backup_device from .backup_device import backup_device
return backup_device(*args, **kwargs) return backup_device(*args, **kwargs)
def dispatch_WipeDevice(*args, **kwargs): def dispatch_WipeDevice(*args, **kwargs):
from .wipe_device import wipe_device from .wipe_device import wipe_device
return wipe_device(*args, **kwargs) return wipe_device(*args, **kwargs)
def dispatch_RecoveryDevice(*args, **kwargs): def dispatch_RecoveryDevice(*args, **kwargs):
from .recovery_device import recovery_device from .recovery_device import recovery_device
return recovery_device(*args, **kwargs) return recovery_device(*args, **kwargs)
def dispatch_ApplySettings(*args, **kwargs): def dispatch_ApplySettings(*args, **kwargs):
from .apply_settings import apply_settings from .apply_settings import apply_settings
return apply_settings(*args, **kwargs) return apply_settings(*args, **kwargs)
def dispatch_ApplyFlags(*args, **kwargs): def dispatch_ApplyFlags(*args, **kwargs):
from .apply_flags import apply_flags from .apply_flags import apply_flags
return apply_flags(*args, **kwargs) return apply_flags(*args, **kwargs)
def dispatch_ChangePin(*args, **kwargs): def dispatch_ChangePin(*args, **kwargs):
from .change_pin import change_pin from .change_pin import change_pin
return change_pin(*args, **kwargs) return change_pin(*args, **kwargs)
def dispatch_SetU2FCounter(*args, **kwargs): def dispatch_SetU2FCounter(*args, **kwargs):
from .set_u2f_counter import set_u2f_counter from .set_u2f_counter import set_u2f_counter
return set_u2f_counter(*args, **kwargs) return set_u2f_counter(*args, **kwargs)

@ -5,4 +5,4 @@ from apps.common import storage
async def apply_flags(ctx, msg): async def apply_flags(ctx, msg):
storage.set_flags(msg.flags) storage.set_flags(msg.flags)
return Success(message='Flags applied') return Success(message="Flags applied")

@ -8,12 +8,17 @@ from apps.common.confirm import require_confirm
async def apply_settings(ctx, msg): async def apply_settings(ctx, msg):
if msg.homescreen is None and msg.label is None and msg.use_passphrase is None and msg.passphrase_source is None: if (
raise wire.ProcessError('No setting provided') msg.homescreen is None
and msg.label is None
and msg.use_passphrase is None
and msg.passphrase_source is None
):
raise wire.ProcessError("No setting provided")
if msg.homescreen is not None: if msg.homescreen is not None:
if len(msg.homescreen) > storage.HOMESCREEN_MAXSIZE: if len(msg.homescreen) > storage.HOMESCREEN_MAXSIZE:
raise wire.DataError('Homescreen is too complex') raise wire.DataError("Homescreen is too complex")
await require_confirm_change_homescreen(ctx) await require_confirm_change_homescreen(ctx)
if msg.label is not None: if msg.label is not None:
@ -25,43 +30,45 @@ async def apply_settings(ctx, msg):
if msg.passphrase_source is not None: if msg.passphrase_source is not None:
await require_confirm_change_passphrase_source(ctx, msg.passphrase_source) await require_confirm_change_passphrase_source(ctx, msg.passphrase_source)
storage.load_settings(label=msg.label, storage.load_settings(
use_passphrase=msg.use_passphrase, label=msg.label,
homescreen=msg.homescreen, use_passphrase=msg.use_passphrase,
passphrase_source=msg.passphrase_source) homescreen=msg.homescreen,
passphrase_source=msg.passphrase_source,
)
return Success(message='Settings applied') return Success(message="Settings applied")
async def require_confirm_change_homescreen(ctx): async def require_confirm_change_homescreen(ctx):
text = Text('Change homescreen', ui.ICON_CONFIG) text = Text("Change homescreen", ui.ICON_CONFIG)
text.normal('Do you really want to', 'change homescreen?') text.normal("Do you really want to", "change homescreen?")
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def require_confirm_change_label(ctx, label): async def require_confirm_change_label(ctx, label):
text = Text('Change label', ui.ICON_CONFIG) text = Text("Change label", ui.ICON_CONFIG)
text.normal('Do you really want to', 'change label to') text.normal("Do you really want to", "change label to")
text.bold('%s?' % label) text.bold("%s?" % label)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def require_confirm_change_passphrase(ctx, use): async def require_confirm_change_passphrase(ctx, use):
text = Text('Enable passphrase' if use else 'Disable passphrase', ui.ICON_CONFIG) text = Text("Enable passphrase" if use else "Disable passphrase", ui.ICON_CONFIG)
text.normal('Do you really want to') text.normal("Do you really want to")
text.normal('enable passphrase' if use else 'disable passphrase') text.normal("enable passphrase" if use else "disable passphrase")
text.normal('encryption?') text.normal("encryption?")
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def require_confirm_change_passphrase_source(ctx, source): async def require_confirm_change_passphrase_source(ctx, source):
if source == PassphraseSourceType.DEVICE: if source == PassphraseSourceType.DEVICE:
desc = 'ON DEVICE' desc = "ON DEVICE"
elif source == PassphraseSourceType.HOST: elif source == PassphraseSourceType.HOST:
desc = 'ON HOST' desc = "ON HOST"
else: else:
desc = 'ASK' desc = "ASK"
text = Text('Passphrase source', ui.ICON_CONFIG) text = Text("Passphrase source", ui.ICON_CONFIG)
text.normal('Do you really want to', 'change the passphrase', 'source to') text.normal("Do you really want to", "change the passphrase", "source to")
text.bold('ALWAYS %s?' % desc) text.bold("ALWAYS %s?" % desc)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)

@ -12,9 +12,9 @@ from apps.management.reset_device import (
async def backup_device(ctx, msg): async def backup_device(ctx, msg):
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.ProcessError('Device is not initialized') raise wire.ProcessError("Device is not initialized")
if not storage.needs_backup(): if not storage.needs_backup():
raise wire.ProcessError('Seed already backed up') raise wire.ProcessError("Seed already backed up")
mnemonic = storage.get_mnemonic() mnemonic = storage.get_mnemonic()
@ -33,4 +33,4 @@ async def backup_device(ctx, msg):
storage.set_unfinished_backup(False) storage.set_unfinished_backup(False)
return Success(message='Seed successfully backed up') return Success(message="Seed successfully backed up")

@ -18,52 +18,52 @@ async def change_pin(ctx, msg):
if config.has_pin(): if config.has_pin():
curpin = await request_pin_ack(ctx) curpin = await request_pin_ack(ctx)
if not config.check_pin(pin_to_int(curpin), show_pin_timeout): if not config.check_pin(pin_to_int(curpin), show_pin_timeout):
raise wire.PinInvalid('PIN invalid') raise wire.PinInvalid("PIN invalid")
else: else:
curpin = '' curpin = ""
# get new pin # get new pin
if not msg.remove: if not msg.remove:
newpin = await request_pin_confirm(ctx) newpin = await request_pin_confirm(ctx)
else: else:
newpin = '' newpin = ""
# write into storage # write into storage
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout): if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout):
raise wire.PinInvalid('PIN invalid') raise wire.PinInvalid("PIN invalid")
if newpin: if newpin:
return Success(message='PIN changed') return Success(message="PIN changed")
else: else:
return Success(message='PIN removed') return Success(message="PIN removed")
def require_confirm_change_pin(ctx, msg): def require_confirm_change_pin(ctx, msg):
has_pin = config.has_pin() has_pin = config.has_pin()
if msg.remove and has_pin: # removing pin if msg.remove and has_pin: # removing pin
text = Text('Remove PIN', ui.ICON_CONFIG) text = Text("Remove PIN", ui.ICON_CONFIG)
text.normal('Do you really want to') text.normal("Do you really want to")
text.bold('remove current PIN?') text.bold("remove current PIN?")
return require_confirm(ctx, text) return require_confirm(ctx, text)
if not msg.remove and has_pin: # changing pin if not msg.remove and has_pin: # changing pin
text = Text('Remove PIN', ui.ICON_CONFIG) text = Text("Remove PIN", ui.ICON_CONFIG)
text.normal('Do you really want to') text.normal("Do you really want to")
text.bold('change current PIN?') text.bold("change current PIN?")
return require_confirm(ctx, text) return require_confirm(ctx, text)
if not msg.remove and not has_pin: # setting new pin if not msg.remove and not has_pin: # setting new pin
text = Text('Remove PIN', ui.ICON_CONFIG) text = Text("Remove PIN", ui.ICON_CONFIG)
text.normal('Do you really want to') text.normal("Do you really want to")
text.bold('set new PIN?') text.bold("set new PIN?")
return require_confirm(ctx, text) return require_confirm(ctx, text)
async def request_pin_confirm(ctx, *args, **kwargs): async def request_pin_confirm(ctx, *args, **kwargs):
while True: while True:
pin1 = await request_pin_ack(ctx, 'Enter new PIN', *args, **kwargs) pin1 = await request_pin_ack(ctx, "Enter new PIN", *args, **kwargs)
pin2 = await request_pin_ack(ctx, 'Re-enter new PIN', *args, **kwargs) pin2 = await request_pin_ack(ctx, "Re-enter new PIN", *args, **kwargs)
if pin1 == pin2: if pin1 == pin2:
return pin1 return pin1
await pin_mismatch() await pin_mismatch()
@ -71,17 +71,19 @@ async def request_pin_confirm(ctx, *args, **kwargs):
async def request_pin_ack(ctx, *args, **kwargs): async def request_pin_ack(ctx, *args, **kwargs):
try: try:
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck) await ctx.call(
ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck
)
return await ctx.wait(request_pin(*args, **kwargs)) return await ctx.wait(request_pin(*args, **kwargs))
except PinCancelled: except PinCancelled:
raise wire.ActionCancelled('Cancelled') raise wire.ActionCancelled("Cancelled")
@ui.layout @ui.layout
async def pin_mismatch(): async def pin_mismatch():
text = Text('PIN mismatch', ui.ICON_WRONG, icon_color=ui.RED) text = Text("PIN mismatch", ui.ICON_WRONG, icon_color=ui.RED)
text.normal('Entered PINs do not', 'match each other.') text.normal("Entered PINs do not", "match each other.")
text.normal('') text.normal("")
text.normal('Please, try again...') text.normal("Please, try again...")
text.render() text.render()
await loop.sleep(3 * 1000 * 1000) await loop.sleep(3 * 1000 * 1000)

@ -11,24 +11,22 @@ from apps.common.confirm import require_confirm
async def load_device(ctx, msg): async def load_device(ctx, msg):
if storage.is_initialized(): if storage.is_initialized():
raise wire.UnexpectedMessage('Already initialized') raise wire.UnexpectedMessage("Already initialized")
if msg.node is not None: if msg.node is not None:
raise wire.ProcessError('LoadDevice.node is not supported') raise wire.ProcessError("LoadDevice.node is not supported")
if not msg.skip_checksum and not bip39.check(msg.mnemonic): if not msg.skip_checksum and not bip39.check(msg.mnemonic):
raise wire.ProcessError('Mnemonic is not valid') raise wire.ProcessError("Mnemonic is not valid")
text = Text('Loading seed') text = Text("Loading seed")
text.bold('Loading private seed', 'is not recommended.') text.bold("Loading private seed", "is not recommended.")
text.normal('Continue only if you', 'know what you are doing!') text.normal("Continue only if you", "know what you are doing!")
await require_confirm(ctx, text) await require_confirm(ctx, text)
storage.load_mnemonic( storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True)
mnemonic=msg.mnemonic, needs_backup=True) storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label)
storage.load_settings(
use_passphrase=msg.passphrase_protection, label=msg.label)
if msg.pin: if msg.pin:
config.change_pin(pin_to_int(''), pin_to_int(msg.pin), None) config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None)
return Success(message='Device loaded') return Success(message="Device loaded")

@ -15,7 +15,7 @@ from apps.management.change_pin import request_pin_confirm
async def recovery_device(ctx, msg): async def recovery_device(ctx, msg):
''' """
Recover BIP39 seed into empty device. Recover BIP39 seed into empty device.
1. Ask for the number of words in recovered seed. 1. Ask for the number of words in recovered seed.
@ -23,9 +23,9 @@ async def recovery_device(ctx, msg):
3. Optionally check the seed validity. 3. Optionally check the seed validity.
4. Optionally ask for the PIN, with confirmation. 4. Optionally ask for the PIN, with confirmation.
5. Save into storage. 5. Save into storage.
''' """
if not msg.dry_run and storage.is_initialized(): if not msg.dry_run and storage.is_initialized():
raise wire.UnexpectedMessage('Already initialized') raise wire.UnexpectedMessage("Already initialized")
# ask for the number of words # ask for the number of words
wordcount = await request_wordcount(ctx) wordcount = await request_wordcount(ctx)
@ -36,7 +36,7 @@ async def recovery_device(ctx, msg):
# check mnemonic validity # check mnemonic validity
if msg.enforce_wordlist or msg.dry_run: if msg.enforce_wordlist or msg.dry_run:
if not bip39.check(mnemonic): if not bip39.check(mnemonic):
raise wire.ProcessError('Mnemonic is not valid') raise wire.ProcessError("Mnemonic is not valid")
# ask for pin repeatedly # ask for pin repeatedly
if msg.pin_protection: if msg.pin_protection:
@ -45,25 +45,27 @@ async def recovery_device(ctx, msg):
# save into storage # save into storage
if not msg.dry_run: if not msg.dry_run:
if msg.pin_protection: if msg.pin_protection:
config.change_pin(pin_to_int(''), pin_to_int(newpin), None) config.change_pin(pin_to_int(""), pin_to_int(newpin), None)
storage.load_settings( storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
label=msg.label, use_passphrase=msg.passphrase_protection) storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False)
storage.load_mnemonic( return Success(message="Device recovered")
mnemonic=mnemonic, needs_backup=False)
return Success(message='Device recovered')
else: else:
if storage.get_mnemonic() == mnemonic: if storage.get_mnemonic() == mnemonic:
return Success(message='The seed is valid and matches the one in the device') return Success(
message="The seed is valid and matches the one in the device"
)
else: else:
raise wire.ProcessError('The seed is valid but does not match the one in the device') raise wire.ProcessError(
"The seed is valid but does not match the one in the device"
)
@ui.layout @ui.layout
async def request_wordcount(ctx): async def request_wordcount(ctx):
await ctx.call(ButtonRequest(code=MnemonicWordCount), ButtonAck) await ctx.call(ButtonRequest(code=MnemonicWordCount), ButtonAck)
text = Text('Device recovery', ui.ICON_RECOVERY) text = Text("Device recovery", ui.ICON_RECOVERY)
text.normal('Number of words?') text.normal("Number of words?")
count = await ctx.wait(WordSelector(text)) count = await ctx.wait(WordSelector(text))
return count return count
@ -76,8 +78,8 @@ async def request_mnemonic(ctx, count: int) -> str:
words = [] words = []
board = MnemonicKeyboard() board = MnemonicKeyboard()
for i in range(count): for i in range(count):
board.prompt = 'Type the %s word:' % format_ordinal(i + 1) board.prompt = "Type the %s word:" % format_ordinal(i + 1)
word = await ctx.wait(board) word = await ctx.wait(board)
words.append(word) words.append(word)
return ' '.join(words) return " ".join(words)

@ -25,15 +25,15 @@ if __debug__:
async def reset_device(ctx, msg): async def reset_device(ctx, msg):
# validate parameters and device state # validate parameters and device state
if msg.strength not in (128, 192, 256): if msg.strength not in (128, 192, 256):
raise wire.ProcessError('Invalid strength (has to be 128, 192 or 256 bits)') raise wire.ProcessError("Invalid strength (has to be 128, 192 or 256 bits)")
if storage.is_initialized(): if storage.is_initialized():
raise wire.UnexpectedMessage('Already initialized') raise wire.UnexpectedMessage("Already initialized")
# request new PIN # request new PIN
if msg.pin_protection: if msg.pin_protection:
newpin = await request_pin_confirm(ctx) newpin = await request_pin_confirm(ctx)
else: else:
newpin = '' newpin = ""
# generate and display internal entropy # generate and display internal entropy
internal_ent = random.bytes(32) internal_ent = random.bytes(32)
@ -58,14 +58,12 @@ async def reset_device(ctx, msg):
await show_wrong_entry(ctx) await show_wrong_entry(ctx)
# write PIN into storage # write PIN into storage
if not config.change_pin(pin_to_int(''), pin_to_int(newpin), None): if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None):
raise wire.ProcessError('Could not change PIN') raise wire.ProcessError("Could not change PIN")
# write settings and mnemonic into storage # write settings and mnemonic into storage
storage.load_settings( storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
label=msg.label, use_passphrase=msg.passphrase_protection) storage.load_mnemonic(mnemonic=mnemonic, needs_backup=msg.skip_backup)
storage.load_mnemonic(
mnemonic=mnemonic, needs_backup=msg.skip_backup)
# show success message. if we skipped backup, it's possible that homescreen # show success message. if we skipped backup, it's possible that homescreen
# is still running, uninterrupted. restart it to pick up new label. # is still running, uninterrupted. restart it to pick up new label.
@ -74,78 +72,64 @@ async def reset_device(ctx, msg):
else: else:
workflow.restartdefault() workflow.restartdefault()
return Success(message='Initialized') return Success(message="Initialized")
def generate_mnemonic(strength: int, def generate_mnemonic(strength: int, int_entropy: bytes, ext_entropy: bytes) -> bytes:
int_entropy: bytes,
ext_entropy: bytes) -> bytes:
ehash = hashlib.sha256() ehash = hashlib.sha256()
ehash.update(int_entropy) ehash.update(int_entropy)
ehash.update(ext_entropy) ehash.update(ext_entropy)
entropy = ehash.digest() entropy = ehash.digest()
mnemonic = bip39.from_data(entropy[:strength // 8]) mnemonic = bip39.from_data(entropy[: strength // 8])
return mnemonic return mnemonic
async def show_warning(ctx): async def show_warning(ctx):
text = Text('Backup your seed', ui.ICON_NOCOPY) text = Text("Backup your seed", ui.ICON_NOCOPY)
text.normal( text.normal(
'Never make a digital', "Never make a digital",
'copy of your recovery', "copy of your recovery",
'seed and never upload', "seed and never upload",
'it online!') "it online!",
)
await require_confirm( await require_confirm(
ctx, ctx, text, ButtonRequestType.ResetDevice, confirm="I understand", cancel=None
text, )
ButtonRequestType.ResetDevice,
confirm='I understand',
cancel=None)
async def show_wrong_entry(ctx): async def show_wrong_entry(ctx):
text = Text('Wrong entry!', ui.ICON_WRONG, icon_color=ui.RED) text = Text("Wrong entry!", ui.ICON_WRONG, icon_color=ui.RED)
text.normal( text.normal("You have entered", "wrong seed word.", "Please check again.")
'You have entered',
'wrong seed word.',
'Please check again.')
await require_confirm( await require_confirm(
ctx, ctx, text, ButtonRequestType.ResetDevice, confirm="Check again", cancel=None
text, )
ButtonRequestType.ResetDevice,
confirm='Check again',
cancel=None)
async def show_success(ctx): async def show_success(ctx):
text = Text('Backup is done!', ui.ICON_CONFIRM, icon_color=ui.GREEN) text = Text("Backup is done!", ui.ICON_CONFIRM, icon_color=ui.GREEN)
text.normal( text.normal(
'Never make a digital', "Never make a digital",
'copy of your recovery', "copy of your recovery",
'seed and never upload', "seed and never upload",
'it online!') "it online!",
)
await require_confirm( await require_confirm(
ctx, ctx, text, ButtonRequestType.ResetDevice, confirm="Finish setup", cancel=None
text, )
ButtonRequestType.ResetDevice,
confirm='Finish setup',
cancel=None)
async def show_entropy(ctx, entropy: bytes): async def show_entropy(ctx, entropy: bytes):
entropy_str = hexlify(entropy).decode() entropy_str = hexlify(entropy).decode()
lines = chunks(entropy_str, 16) lines = chunks(entropy_str, 16)
text = Text('Internal entropy', ui.ICON_RESET) text = Text("Internal entropy", ui.ICON_RESET)
text.mono(*lines) text.mono(*lines)
await require_confirm( await require_confirm(ctx, text, ButtonRequestType.ResetDevice)
ctx,
text,
ButtonRequestType.ResetDevice)
async def show_mnemonic(ctx, mnemonic: str): async def show_mnemonic(ctx, mnemonic: str):
await ctx.call( await ctx.call(
ButtonRequest(code=ButtonRequestType.ResetDevice), MessageType.ButtonAck) ButtonRequest(code=ButtonRequestType.ResetDevice), MessageType.ButtonAck
)
first_page = const(0) first_page = const(0)
words_per_page = const(4) words_per_page = const(4)
words = list(enumerate(mnemonic.split())) words = list(enumerate(mnemonic.split()))
@ -159,8 +143,8 @@ async def show_mnemonic_page(page: int, page_count: int, pages: list):
if __debug__: if __debug__:
debug.reset_current_words = [word for _, word in pages[page]] debug.reset_current_words = [word for _, word in pages[page]]
lines = ['%2d. %s' % (wi + 1, word) for wi, word in pages[page]] lines = ["%2d. %s" % (wi + 1, word) for wi, word in pages[page]]
text = Text('Recovery seed', ui.ICON_RESET) text = Text("Recovery seed", ui.ICON_RESET)
text.mono(*lines) text.mono(*lines)
content = Scrollpage(text, page, page_count) content = Scrollpage(text, page, page_count)
@ -192,6 +176,6 @@ async def check_word(ctx, words: list, index: int):
if __debug__: if __debug__:
debug.reset_word_index = index debug.reset_word_index = index
keyboard = MnemonicKeyboard('Type the %s word:' % format_ordinal(index + 1)) keyboard = MnemonicKeyboard("Type the %s word:" % format_ordinal(index + 1))
result = await ctx.wait(keyboard) result = await ctx.wait(keyboard)
return result == words[index] return result == words[index]

@ -9,13 +9,13 @@ from apps.common.confirm import require_confirm
async def set_u2f_counter(ctx, msg): async def set_u2f_counter(ctx, msg):
if msg.u2f_counter is None: if msg.u2f_counter is None:
raise wire.ProcessError('No value provided') raise wire.ProcessError("No value provided")
text = Text('Set U2F counter', ui.ICON_CONFIG) text = Text("Set U2F counter", ui.ICON_CONFIG)
text.normal('Do you really want to', 'set the U2F counter') text.normal("Do you really want to", "set the U2F counter")
text.bold('to %d?' % msg.u2f_counter) text.bold("to %d?" % msg.u2f_counter)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
storage.set_u2f_counter(msg.u2f_counter) storage.set_u2f_counter(msg.u2f_counter)
return Success(message='U2F counter set') return Success(message="U2F counter set")

@ -9,15 +9,18 @@ from apps.common.confirm import require_hold_to_confirm
async def wipe_device(ctx, msg): async def wipe_device(ctx, msg):
text = Text('Wipe device', ui.ICON_WIPE, icon_color=ui.RED) text = Text("Wipe device", ui.ICON_WIPE, icon_color=ui.RED)
text.normal('Do you really want to', 'wipe the device?', '') text.normal("Do you really want to", "wipe the device?", "")
text.bold('All data will be lost.') text.bold("All data will be lost.")
await require_hold_to_confirm(ctx, text, await require_hold_to_confirm(
ctx,
text,
code=ButtonRequestType.WipeDevice, code=ButtonRequestType.WipeDevice,
button_style=ui.BTN_CANCEL, button_style=ui.BTN_CANCEL,
loader_style=ui.LDR_DANGER) loader_style=ui.LDR_DANGER,
)
storage.wipe() storage.wipe()
return Success(message='Device wiped') return Success(message="Device wiped")

@ -4,11 +4,13 @@ from trezor.wire import protobuf_workflow, register
def dispatch_NemGetAddress(*args, **kwargs): def dispatch_NemGetAddress(*args, **kwargs):
from .get_address import get_address from .get_address import get_address
return get_address(*args, **kwargs) return get_address(*args, **kwargs)
def dispatch_NemSignTx(*args, **kwargs): def dispatch_NemSignTx(*args, **kwargs):
from .signing import sign_tx from .signing import sign_tx
return sign_tx(*args, **kwargs) return sign_tx(*args, **kwargs)

@ -30,7 +30,9 @@ async def get_address(ctx, msg):
async def _show_address(ctx, address: str, network: int): async def _show_address(ctx, address: str, network: int):
lines = split_address(address) lines = split_address(address)
text = Text('Confirm address', ui.ICON_RECEIVE, icon_color=ui.GREEN) text = Text("Confirm address", ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.normal('%s network' % get_network_str(network)) text.normal("%s network" % get_network_str(network))
text.mono(*lines) text.mono(*lines)
return await confirm(ctx, text, code=ButtonRequestType.Address, cancel='QR', cancel_style=ui.BTN_KEY) return await confirm(
ctx, text, code=ButtonRequestType.Address, cancel="QR", cancel_style=ui.BTN_KEY
)

@ -3,7 +3,7 @@ from micropython import const
NEM_NETWORK_MAINNET = const(0x68) NEM_NETWORK_MAINNET = const(0x68)
NEM_NETWORK_TESTNET = const(0x98) NEM_NETWORK_TESTNET = const(0x98)
NEM_NETWORK_MIJIN = const(0x60) NEM_NETWORK_MIJIN = const(0x60)
NEM_CURVE = 'ed25519-keccak' NEM_CURVE = "ed25519-keccak"
NEM_TRANSACTION_TYPE_TRANSFER = const(0x0101) NEM_TRANSACTION_TYPE_TRANSFER = const(0x0101)
NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER = const(0x0801) NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER = const(0x0801)
@ -19,7 +19,7 @@ NEM_MAX_SUPPLY = const(9000000000)
NEM_SALT_SIZE = const(32) NEM_SALT_SIZE = const(32)
AES_BLOCK_SIZE = const(16) AES_BLOCK_SIZE = const(16)
NEM_HASH_ALG = 'keccak' NEM_HASH_ALG = "keccak"
NEM_PUBLIC_KEY_SIZE = const(32) # ed25519 public key NEM_PUBLIC_KEY_SIZE = const(32) # ed25519 public key
NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE = const(10000) NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE = const(10000)
NEM_MOSAIC_AMOUNT_DIVISOR = const(1000000) NEM_MOSAIC_AMOUNT_DIVISOR = const(1000000)
@ -30,8 +30,8 @@ NEM_MAX_ENCRYPTED_PAYLOAD_SIZE = const(960)
def get_network_str(network: int) -> str: def get_network_str(network: int) -> str:
if network == NEM_NETWORK_MAINNET: if network == NEM_NETWORK_MAINNET:
return 'Mainnet' return "Mainnet"
elif network == NEM_NETWORK_TESTNET: elif network == NEM_NETWORK_TESTNET:
return 'Testnet' return "Testnet"
elif network == NEM_NETWORK_MIJIN: elif network == NEM_NETWORK_MIJIN:
return 'Mijin' return "Mijin"

@ -10,15 +10,17 @@ from apps.common.confirm import require_confirm, require_hold_to_confirm
async def require_confirm_text(ctx, action: str): async def require_confirm_text(ctx, action: str):
words = split_words(action, 18) words = split_words(action, 18)
await require_confirm_content(ctx, 'Confirm action', words) await require_confirm_content(ctx, "Confirm action", words)
async def require_confirm_fee(ctx, action: str, fee: int): async def require_confirm_fee(ctx, action: str, fee: int):
content = ( content = (
ui.NORMAL, action, ui.NORMAL,
ui.BOLD, '%s XEM' % format_amount(fee, NEM_MAX_DIVISIBILITY), action,
ui.BOLD,
"%s XEM" % format_amount(fee, NEM_MAX_DIVISIBILITY),
) )
await require_confirm_content(ctx, 'Confirm fee', content) await require_confirm_content(ctx, "Confirm fee", content)
async def require_confirm_content(ctx, headline: str, content: list): async def require_confirm_content(ctx, headline: str, content: list):
@ -28,10 +30,10 @@ async def require_confirm_content(ctx, headline: str, content: list):
async def require_confirm_final(ctx, fee: int): async def require_confirm_final(ctx, fee: int):
text = Text('Final confirm', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Final confirm", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Sign this transaction') text.normal("Sign this transaction")
text.bold('and pay %s XEM' % format_amount(fee, NEM_MAX_DIVISIBILITY)) text.bold("and pay %s XEM" % format_amount(fee, NEM_MAX_DIVISIBILITY))
text.normal('for network fee?') text.normal("for network fee?")
# we use SignTx, not ConfirmOutput, for compatibility with T1 # we use SignTx, not ConfirmOutput, for compatibility with T1
await require_hold_to_confirm(ctx, text, ButtonRequestType.SignTx) await require_hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
@ -42,5 +44,5 @@ def split_address(address: str):
def trim(payload: str, length: int) -> str: def trim(payload: str, length: int) -> str:
if len(payload) > length: if len(payload) > length:
return payload[:length] + '..' return payload[:length] + ".."
return payload return payload

@ -5,11 +5,15 @@ from trezor.messages.NEMTransactionCommon import NEMTransactionCommon
from . import layout, serialize from . import layout, serialize
async def mosaic_creation(ctx, public_key: bytes, common: NEMTransactionCommon, creation: NEMMosaicCreation) -> bytearray: async def mosaic_creation(
ctx, public_key: bytes, common: NEMTransactionCommon, creation: NEMMosaicCreation
) -> bytearray:
await layout.ask_mosaic_creation(ctx, common, creation) await layout.ask_mosaic_creation(ctx, common, creation)
return serialize.serialize_mosaic_creation(common, creation, public_key) return serialize.serialize_mosaic_creation(common, creation, public_key)
async def supply_change(ctx, public_key: bytes, common: NEMTransactionCommon, change: NEMMosaicSupplyChange) -> bytearray: async def supply_change(
ctx, public_key: bytes, common: NEMTransactionCommon, change: NEMMosaicSupplyChange
) -> bytearray:
await layout.ask_supply_change(ctx, common, change) await layout.ask_supply_change(ctx, common, change)
return serialize.serialize_mosaic_supply_change(common, change, public_key) return serialize.serialize_mosaic_supply_change(common, change, public_key)

@ -24,39 +24,55 @@ from ..layout import (
) )
async def ask_mosaic_creation(ctx, common: NEMTransactionCommon, creation: NEMMosaicCreation): async def ask_mosaic_creation(
await require_confirm_content(ctx, 'Create mosaic', _creation_message(creation)) ctx, common: NEMTransactionCommon, creation: NEMMosaicCreation
):
await require_confirm_content(ctx, "Create mosaic", _creation_message(creation))
await _require_confirm_properties(ctx, creation.definition) await _require_confirm_properties(ctx, creation.definition)
await require_confirm_fee(ctx, 'Confirm creation fee', creation.fee) await require_confirm_fee(ctx, "Confirm creation fee", creation.fee)
await require_confirm_final(ctx, common.fee) await require_confirm_final(ctx, common.fee)
async def ask_supply_change(ctx, common: NEMTransactionCommon, change: NEMMosaicSupplyChange): async def ask_supply_change(
await require_confirm_content(ctx, 'Supply change', _supply_message(change)) ctx, common: NEMTransactionCommon, change: NEMMosaicSupplyChange
):
await require_confirm_content(ctx, "Supply change", _supply_message(change))
if change.type == NEMSupplyChangeType.SupplyChange_Decrease: if change.type == NEMSupplyChangeType.SupplyChange_Decrease:
msg = 'Decrease supply by ' + str(change.delta) + ' whole units?' msg = "Decrease supply by " + str(change.delta) + " whole units?"
elif change.type == NEMSupplyChangeType.SupplyChange_Increase: elif change.type == NEMSupplyChangeType.SupplyChange_Increase:
msg = 'Increase supply by ' + str(change.delta) + ' whole units?' msg = "Increase supply by " + str(change.delta) + " whole units?"
else: else:
raise ValueError('Invalid supply change type') raise ValueError("Invalid supply change type")
await require_confirm_text(ctx, msg) await require_confirm_text(ctx, msg)
await require_confirm_final(ctx, common.fee) await require_confirm_final(ctx, common.fee)
def _creation_message(mosaic_creation): def _creation_message(mosaic_creation):
return [ui.NORMAL, 'Create mosaic', return [
ui.BOLD, mosaic_creation.definition.mosaic, ui.NORMAL,
ui.NORMAL, 'under namespace', "Create mosaic",
ui.BOLD, mosaic_creation.definition.namespace] ui.BOLD,
mosaic_creation.definition.mosaic,
ui.NORMAL,
"under namespace",
ui.BOLD,
mosaic_creation.definition.namespace,
]
def _supply_message(supply_change): def _supply_message(supply_change):
return [ui.NORMAL, 'Modify supply for', return [
ui.BOLD, supply_change.mosaic, ui.NORMAL,
ui.NORMAL, 'under namespace', "Modify supply for",
ui.BOLD, supply_change.namespace] ui.BOLD,
supply_change.mosaic,
ui.NORMAL,
"under namespace",
ui.BOLD,
supply_change.namespace,
]
async def _require_confirm_properties(ctx, definition: NEMMosaicDefinition): async def _require_confirm_properties(ctx, definition: NEMMosaicDefinition):
@ -81,64 +97,64 @@ def _get_mosaic_properties(definition: NEMMosaicDefinition):
# description # description
if definition.description: if definition.description:
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Description:') t.bold("Description:")
t.normal(*split_words(trim(definition.description, 70), 22)) t.normal(*split_words(trim(definition.description, 70), 22))
properties.append(t) properties.append(t)
# transferable # transferable
if definition.transferable: if definition.transferable:
transferable = 'Yes' transferable = "Yes"
else: else:
transferable = 'No' transferable = "No"
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Transferable?') t.bold("Transferable?")
t.normal(transferable) t.normal(transferable)
properties.append(t) properties.append(t)
# mutable_supply # mutable_supply
if definition.mutable_supply: if definition.mutable_supply:
imm = 'mutable' imm = "mutable"
else: else:
imm = 'immutable' imm = "immutable"
if definition.supply: if definition.supply:
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Initial supply:') t.bold("Initial supply:")
t.normal(str(definition.supply), imm) t.normal(str(definition.supply), imm)
else: else:
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Initial supply:') t.bold("Initial supply:")
t.normal(imm) t.normal(imm)
properties.append(t) properties.append(t)
# levy # levy
if definition.levy: if definition.levy:
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Levy recipient:') t.bold("Levy recipient:")
t.mono(*split_address(definition.levy_address)) t.mono(*split_address(definition.levy_address))
properties.append(t) properties.append(t)
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Levy fee:') t.bold("Levy fee:")
t.normal(str(definition.fee)) t.normal(str(definition.fee))
t.bold('Levy divisibility:') t.bold("Levy divisibility:")
t.normal(str(definition.divisibility)) t.normal(str(definition.divisibility))
properties.append(t) properties.append(t)
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Levy namespace:') t.bold("Levy namespace:")
t.normal(definition.levy_namespace) t.normal(definition.levy_namespace)
t.bold('Levy mosaic:') t.bold("Levy mosaic:")
t.normal(definition.levy_mosaic) t.normal(definition.levy_mosaic)
properties.append(t) properties.append(t)
if definition.levy == NEMMosaicLevy.MosaicLevy_Absolute: if definition.levy == NEMMosaicLevy.MosaicLevy_Absolute:
levy_type = 'absolute' levy_type = "absolute"
else: else:
levy_type = 'percentile' levy_type = "percentile"
t = Text('Confirm properties', ui.ICON_SEND) t = Text("Confirm properties", ui.ICON_SEND)
t.bold('Levy type:') t.bold("Levy type:")
t.normal(levy_type) t.normal(levy_type)
properties.append(t) properties.append(t)

@ -9,30 +9,51 @@ from ..helpers import (
from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64 from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64
def serialize_mosaic_creation(common: NEMTransactionCommon, creation: NEMMosaicCreation, public_key: bytes): def serialize_mosaic_creation(
w = write_common(common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_CREATION) common: NEMTransactionCommon, creation: NEMMosaicCreation, public_key: bytes
):
w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_CREATION
)
mosaics_w = bytearray() mosaics_w = bytearray()
write_bytes_with_length(mosaics_w, bytearray(public_key)) write_bytes_with_length(mosaics_w, bytearray(public_key))
identifier_length = 4 + len(creation.definition.namespace) + 4 + len(creation.definition.mosaic) identifier_length = (
4 + len(creation.definition.namespace) + 4 + len(creation.definition.mosaic)
)
write_uint32(mosaics_w, identifier_length) write_uint32(mosaics_w, identifier_length)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.namespace)) write_bytes_with_length(mosaics_w, bytearray(creation.definition.namespace))
write_bytes_with_length(mosaics_w, bytearray(creation.definition.mosaic)) write_bytes_with_length(mosaics_w, bytearray(creation.definition.mosaic))
write_bytes_with_length(mosaics_w, bytearray(creation.definition.description)) write_bytes_with_length(mosaics_w, bytearray(creation.definition.description))
write_uint32(mosaics_w, 4) # number of properties write_uint32(mosaics_w, 4) # number of properties
_write_property(mosaics_w, 'divisibility', creation.definition.divisibility) _write_property(mosaics_w, "divisibility", creation.definition.divisibility)
_write_property(mosaics_w, 'initialSupply', creation.definition.supply) _write_property(mosaics_w, "initialSupply", creation.definition.supply)
_write_property(mosaics_w, 'supplyMutable', creation.definition.mutable_supply) _write_property(mosaics_w, "supplyMutable", creation.definition.mutable_supply)
_write_property(mosaics_w, 'transferable', creation.definition.transferable) _write_property(mosaics_w, "transferable", creation.definition.transferable)
if creation.definition.levy: if creation.definition.levy:
levy_identifier_length = 4 + len(creation.definition.levy_namespace) + 4 + len(creation.definition.levy_mosaic) levy_identifier_length = (
write_uint32(mosaics_w, 4 + 4 + len(creation.definition.levy_address) + 4 + levy_identifier_length + 8) 4
+ len(creation.definition.levy_namespace)
+ 4
+ len(creation.definition.levy_mosaic)
)
write_uint32(
mosaics_w,
4
+ 4
+ len(creation.definition.levy_address)
+ 4
+ levy_identifier_length
+ 8,
)
write_uint32(mosaics_w, creation.definition.levy) write_uint32(mosaics_w, creation.definition.levy)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_address)) write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_address))
write_uint32(mosaics_w, levy_identifier_length) write_uint32(mosaics_w, levy_identifier_length)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_namespace)) write_bytes_with_length(
mosaics_w, bytearray(creation.definition.levy_namespace)
)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_mosaic)) write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_mosaic))
write_uint64(mosaics_w, creation.definition.fee) write_uint64(mosaics_w, creation.definition.fee)
else: else:
@ -47,8 +68,12 @@ def serialize_mosaic_creation(common: NEMTransactionCommon, creation: NEMMosaicC
return w return w
def serialize_mosaic_supply_change(common: NEMTransactionCommon, change: NEMMosaicSupplyChange, public_key: bytes): def serialize_mosaic_supply_change(
w = write_common(common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_SUPPLY_CHANGE) common: NEMTransactionCommon, change: NEMMosaicSupplyChange, public_key: bytes
):
w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_SUPPLY_CHANGE
)
identifier_length = 4 + len(change.namespace) + 4 + len(change.mosaic) identifier_length = 4 + len(change.namespace) + 4 + len(change.mosaic)
write_uint32(w, identifier_length) write_uint32(w, identifier_length)
@ -62,19 +87,19 @@ def serialize_mosaic_supply_change(common: NEMTransactionCommon, change: NEMMosa
def _write_property(w: bytearray, name: str, value): def _write_property(w: bytearray, name: str, value):
if value is None: if value is None:
if name in ('divisibility', 'initialSupply'): if name in ("divisibility", "initialSupply"):
value = 0 value = 0
elif name in ('supplyMutable', 'transferable'): elif name in ("supplyMutable", "transferable"):
value = False value = False
if type(value) == bool: if type(value) == bool:
if value: if value:
value = 'true' value = "true"
else: else:
value = 'false' value = "false"
elif type(value) == int: elif type(value) == int:
value = str(value) value = str(value)
elif type(value) != str: elif type(value) != str:
raise ValueError('Incompatible value type') raise ValueError("Incompatible value type")
write_uint32(w, 4 + len(name) + 4 + len(value)) write_uint32(w, 4 + len(name) + 4 + len(value))
write_bytes_with_length(w, bytearray(name)) write_bytes_with_length(w, bytearray(name))
write_bytes_with_length(w, bytearray(value)) write_bytes_with_length(w, bytearray(value))

@ -13,15 +13,19 @@ def initiate(public_key, common: NEMTransactionCommon, inner_tx: bytes) -> bytes
return serialize.serialize_multisig(common, public_key, inner_tx) return serialize.serialize_multisig(common, public_key, inner_tx)
def cosign(public_key, common: NEMTransactionCommon, inner_tx: bytes, signer: bytes) -> bytes: def cosign(
public_key, common: NEMTransactionCommon, inner_tx: bytes, signer: bytes
) -> bytes:
return serialize.serialize_multisig_signature(common, public_key, inner_tx, signer) return serialize.serialize_multisig_signature(common, public_key, inner_tx, signer)
async def aggregate_modification(ctx, async def aggregate_modification(
public_key: bytes, ctx,
common: NEMTransactionCommon, public_key: bytes,
aggr: NEMAggregateModification, common: NEMTransactionCommon,
multisig: bool): aggr: NEMAggregateModification,
multisig: bool,
):
await layout.ask_aggregate_modification(ctx, common, aggr, multisig) await layout.ask_aggregate_modification(ctx, common, aggr, multisig)
w = serialize.serialize_aggregate_modification(common, aggr, public_key) w = serialize.serialize_aggregate_modification(common, aggr, public_key)

@ -21,36 +21,38 @@ from ..layout import (
async def ask_multisig(ctx, msg: NEMSignTx): async def ask_multisig(ctx, msg: NEMSignTx):
address = nem.compute_address(msg.multisig.signer, msg.transaction.network) address = nem.compute_address(msg.multisig.signer, msg.transaction.network)
if msg.cosigning: if msg.cosigning:
await _require_confirm_address(ctx, 'Cosign transaction for', address) await _require_confirm_address(ctx, "Cosign transaction for", address)
else: else:
await _require_confirm_address(ctx, 'Initiate transaction for', address) await _require_confirm_address(ctx, "Initiate transaction for", address)
await require_confirm_fee(ctx, 'Confirm multisig fee', msg.transaction.fee) await require_confirm_fee(ctx, "Confirm multisig fee", msg.transaction.fee)
async def ask_aggregate_modification(ctx, common: NEMTransactionCommon, mod: NEMAggregateModification, multisig: bool): async def ask_aggregate_modification(
ctx, common: NEMTransactionCommon, mod: NEMAggregateModification, multisig: bool
):
if not multisig: if not multisig:
await require_confirm_text(ctx, 'Convert account to multisig account?') await require_confirm_text(ctx, "Convert account to multisig account?")
for m in mod.modifications: for m in mod.modifications:
if m.type == NEMModificationType.CosignatoryModification_Add: if m.type == NEMModificationType.CosignatoryModification_Add:
action = 'Add' action = "Add"
else: else:
action = 'Remove' action = "Remove"
address = nem.compute_address(m.public_key, common.network) address = nem.compute_address(m.public_key, common.network)
await _require_confirm_address(ctx, action + ' cosignatory', address) await _require_confirm_address(ctx, action + " cosignatory", address)
if mod.relative_change: if mod.relative_change:
if multisig: if multisig:
action = 'Modify the number of cosignatories by ' action = "Modify the number of cosignatories by "
else: else:
action = 'Set minimum cosignatories to ' action = "Set minimum cosignatories to "
await require_confirm_text(ctx, action + str(mod.relative_change) + '?') await require_confirm_text(ctx, action + str(mod.relative_change) + "?")
await require_confirm_final(ctx, common.fee) await require_confirm_final(ctx, common.fee)
async def _require_confirm_address(ctx, action: str, address: str): async def _require_confirm_address(ctx, action: str, address: str):
text = Text('Confirm address', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm address", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(action) text.normal(action)
text.mono(*split_address(address)) text.mono(*split_address(address))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)

@ -16,10 +16,16 @@ def serialize_multisig(common: NEMTransactionCommon, public_key: bytes, inner: b
return w return w
def serialize_multisig_signature(common: NEMTransactionCommon, public_key: bytes, def serialize_multisig_signature(
inner: bytes, address_public_key: bytes): common: NEMTransactionCommon,
public_key: bytes,
inner: bytes,
address_public_key: bytes,
):
address = nem.compute_address(address_public_key, common.network) address = nem.compute_address(address_public_key, common.network)
w = write_common(common, bytearray(public_key), NEM_TRANSACTION_TYPE_MULTISIG_SIGNATURE) w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_MULTISIG_SIGNATURE
)
digest = hashlib.sha3_256(inner).digest(True) digest = hashlib.sha3_256(inner).digest(True)
write_uint32(w, 4 + len(digest)) write_uint32(w, 4 + len(digest))
@ -28,20 +34,26 @@ def serialize_multisig_signature(common: NEMTransactionCommon, public_key: bytes
return w return w
def serialize_aggregate_modification(common: NEMTransactionCommon, mod: NEMAggregateModification, public_key: bytes): def serialize_aggregate_modification(
common: NEMTransactionCommon, mod: NEMAggregateModification, public_key: bytes
):
version = common.network << 24 | 1 version = common.network << 24 | 1
if mod.relative_change: if mod.relative_change:
version = common.network << 24 | 2 version = common.network << 24 | 2
w = write_common(common, w = write_common(
bytearray(public_key), common,
NEM_TRANSACTION_TYPE_AGGREGATE_MODIFICATION, bytearray(public_key),
version) NEM_TRANSACTION_TYPE_AGGREGATE_MODIFICATION,
version,
)
write_uint32(w, len(mod.modifications)) write_uint32(w, len(mod.modifications))
return w return w
def serialize_cosignatory_modification(w: bytearray, type: int, cosignatory_pubkey: bytes): def serialize_cosignatory_modification(
w: bytearray, type: int, cosignatory_pubkey: bytes
):
write_uint32(w, 4 + 4 + len(cosignatory_pubkey)) write_uint32(w, 4 + 4 + len(cosignatory_pubkey))
write_uint32(w, type) write_uint32(w, type)
write_bytes_with_length(w, bytearray(cosignatory_pubkey)) write_bytes_with_length(w, bytearray(cosignatory_pubkey))

@ -4,6 +4,11 @@ from trezor.messages.NEMTransactionCommon import NEMTransactionCommon
from . import layout, serialize from . import layout, serialize
async def namespace(ctx, public_key: bytes, common: NEMTransactionCommon, namespace: NEMProvisionNamespace) -> bytearray: async def namespace(
ctx,
public_key: bytes,
common: NEMTransactionCommon,
namespace: NEMProvisionNamespace,
) -> bytearray:
await layout.ask_provision_namespace(ctx, common, namespace) await layout.ask_provision_namespace(ctx, common, namespace)
return serialize.serialize_provision_namespace(common, namespace, public_key) return serialize.serialize_provision_namespace(common, namespace, public_key)

@ -4,18 +4,25 @@ from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon
from ..layout import require_confirm_content, require_confirm_fee, require_confirm_final from ..layout import require_confirm_content, require_confirm_fee, require_confirm_final
async def ask_provision_namespace(ctx, common: NEMTransactionCommon, namespace: NEMProvisionNamespace): async def ask_provision_namespace(
ctx, common: NEMTransactionCommon, namespace: NEMProvisionNamespace
):
if namespace.parent: if namespace.parent:
content = (ui.NORMAL, 'Create namespace', content = (
ui.BOLD, namespace.namespace, ui.NORMAL,
ui.NORMAL, 'under namespace', "Create namespace",
ui.BOLD, namespace.parent) ui.BOLD,
await require_confirm_content(ctx, 'Confirm namespace', content) namespace.namespace,
ui.NORMAL,
"under namespace",
ui.BOLD,
namespace.parent,
)
await require_confirm_content(ctx, "Confirm namespace", content)
else: else:
content = (ui.NORMAL, 'Create namespace', content = (ui.NORMAL, "Create namespace", ui.BOLD, namespace.namespace)
ui.BOLD, namespace.namespace) await require_confirm_content(ctx, "Confirm namespace", content)
await require_confirm_content(ctx, 'Confirm namespace', content)
await require_confirm_fee(ctx, 'Confirm rental fee', namespace.fee) await require_confirm_fee(ctx, "Confirm rental fee", namespace.fee)
await require_confirm_final(ctx, common.fee) await require_confirm_final(ctx, common.fee)

@ -5,10 +5,12 @@ from ..helpers import NEM_TRANSACTION_TYPE_PROVISION_NAMESPACE
from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64 from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64
def serialize_provision_namespace(common: NEMTransactionCommon, namespace: NEMProvisionNamespace, public_key: bytes) -> bytearray: def serialize_provision_namespace(
tx = write_common(common, common: NEMTransactionCommon, namespace: NEMProvisionNamespace, public_key: bytes
bytearray(public_key), ) -> bytearray:
NEM_TRANSACTION_TYPE_PROVISION_NAMESPACE) tx = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_PROVISION_NAMESPACE
)
write_bytes_with_length(tx, bytearray(namespace.sink)) write_bytes_with_length(tx, bytearray(namespace.sink))
write_uint64(tx, namespace.fee) write_uint64(tx, namespace.fee)

@ -30,16 +30,26 @@ async def sign_tx(ctx, msg: NEMSignTx):
elif msg.supply_change: elif msg.supply_change:
tx = await mosaic.supply_change(ctx, public_key, common, msg.supply_change) tx = await mosaic.supply_change(ctx, public_key, common, msg.supply_change)
elif msg.aggregate_modification: elif msg.aggregate_modification:
tx = await multisig.aggregate_modification(ctx, public_key, common, msg.aggregate_modification, msg.multisig is not None) tx = await multisig.aggregate_modification(
ctx,
public_key,
common,
msg.aggregate_modification,
msg.multisig is not None,
)
elif msg.importance_transfer: elif msg.importance_transfer:
tx = await transfer.importance_transfer(ctx, public_key, common, msg.importance_transfer) tx = await transfer.importance_transfer(
ctx, public_key, common, msg.importance_transfer
)
else: else:
raise ValueError('No transaction provided') raise ValueError("No transaction provided")
if msg.multisig: if msg.multisig:
# wrap transaction in multisig wrapper # wrap transaction in multisig wrapper
if msg.cosigning: if msg.cosigning:
tx = multisig.cosign(_get_public_key(node), msg.transaction, tx, msg.multisig.signer) tx = multisig.cosign(
_get_public_key(node), msg.transaction, tx, msg.multisig.signer
)
else: else:
tx = multisig.initiate(_get_public_key(node), msg.transaction, tx) tx = multisig.initiate(_get_public_key(node), msg.transaction, tx)

@ -5,7 +5,9 @@ from trezor.messages.NEMTransfer import NEMTransfer
from . import layout, serialize from . import layout, serialize
async def transfer(ctx, public_key: bytes, common: NEMTransactionCommon, transfer: NEMTransfer, node): async def transfer(
ctx, public_key: bytes, common: NEMTransactionCommon, transfer: NEMTransfer, node
):
transfer.mosaics = serialize.canonicalize_mosaics(transfer.mosaics) transfer.mosaics = serialize.canonicalize_mosaics(transfer.mosaics)
payload, encrypted = serialize.get_transfer_payload(transfer, node) payload, encrypted = serialize.get_transfer_payload(transfer, node)
@ -17,6 +19,8 @@ async def transfer(ctx, public_key: bytes, common: NEMTransactionCommon, transfe
return w return w
async def importance_transfer(ctx, public_key: bytes, common: NEMTransactionCommon, imp: NEMImportanceTransfer): async def importance_transfer(
ctx, public_key: bytes, common: NEMTransactionCommon, imp: NEMImportanceTransfer
):
await layout.ask_importance_transfer(ctx, common, imp) await layout.ask_importance_transfer(ctx, common, imp)
return serialize.serialize_importance_transfer(common, imp, public_key) return serialize.serialize_importance_transfer(common, imp, public_key)

@ -22,7 +22,13 @@ from ..mosaic.helpers import get_mosaic_definition, is_nem_xem_mosaic
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
async def ask_transfer(ctx, common: NEMTransactionCommon, transfer: NEMTransfer, payload: bytes, encrypted: bool): async def ask_transfer(
ctx,
common: NEMTransactionCommon,
transfer: NEMTransfer,
payload: bytes,
encrypted: bool,
):
if payload: if payload:
await _require_confirm_payload(ctx, transfer.payload, encrypted) await _require_confirm_payload(ctx, transfer.payload, encrypted)
for mosaic in transfer.mosaics: for mosaic in transfer.mosaics:
@ -31,7 +37,9 @@ async def ask_transfer(ctx, common: NEMTransactionCommon, transfer: NEMTransfer,
await require_confirm_final(ctx, common.fee) await require_confirm_final(ctx, common.fee)
async def ask_transfer_mosaic(ctx, common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic): async def ask_transfer_mosaic(
ctx, common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic
):
if is_nem_xem_mosaic(mosaic.namespace, mosaic.mosaic): if is_nem_xem_mosaic(mosaic.namespace, mosaic.mosaic):
return return
@ -39,31 +47,38 @@ async def ask_transfer_mosaic(ctx, common: NEMTransactionCommon, transfer: NEMTr
mosaic_quantity = mosaic.quantity * transfer.amount / NEM_MOSAIC_AMOUNT_DIVISOR mosaic_quantity = mosaic.quantity * transfer.amount / NEM_MOSAIC_AMOUNT_DIVISOR
if definition: if definition:
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.GREEN) msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal('Confirm transfer of') msg.normal("Confirm transfer of")
msg.bold(format_amount(mosaic_quantity, definition['divisibility']) + definition['ticker']) msg.bold(
msg.normal('of') format_amount(mosaic_quantity, definition["divisibility"])
msg.bold(definition['name']) + definition["ticker"]
)
msg.normal("of")
msg.bold(definition["name"])
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
if 'levy' in definition and 'fee' in definition: if "levy" in definition and "fee" in definition:
levy_msg = _get_levy_msg(definition, mosaic_quantity, common.network) levy_msg = _get_levy_msg(definition, mosaic_quantity, common.network)
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.GREEN) msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal('Confirm mosaic', 'levy fee of') msg.normal("Confirm mosaic", "levy fee of")
msg.bold(levy_msg) msg.bold(levy_msg)
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
else: else:
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.RED) msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.RED)
msg.bold('Unknown mosaic!') msg.bold("Unknown mosaic!")
msg.normal(*split_words('Divisibility and levy cannot be shown for unknown mosaics', 22)) msg.normal(
*split_words(
"Divisibility and levy cannot be shown for unknown mosaics", 22
)
)
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.GREEN) msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal('Confirm transfer of') msg.normal("Confirm transfer of")
msg.bold('%s raw units' % mosaic_quantity) msg.bold("%s raw units" % mosaic_quantity)
msg.normal('of') msg.normal("of")
msg.bold('%s.%s' % (mosaic.namespace, mosaic.mosaic)) msg.bold("%s.%s" % (mosaic.namespace, mosaic.mosaic))
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
@ -81,47 +96,50 @@ def _get_xem_amount(transfer: NEMTransfer):
def _get_levy_msg(mosaic_definition, quantity: int, network: int) -> str: def _get_levy_msg(mosaic_definition, quantity: int, network: int) -> str:
levy_definition = get_mosaic_definition( levy_definition = get_mosaic_definition(
mosaic_definition['levy_namespace'], mosaic_definition["levy_namespace"], mosaic_definition["levy_mosaic"], network
mosaic_definition['levy_mosaic'], )
network) if mosaic_definition["levy"] == NEMMosaicLevy.MosaicLevy_Absolute:
if mosaic_definition['levy'] == NEMMosaicLevy.MosaicLevy_Absolute: levy_fee = mosaic_definition["fee"]
levy_fee = mosaic_definition['fee']
else: else:
levy_fee = quantity * mosaic_definition['fee'] / NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE levy_fee = (
return format_amount( quantity * mosaic_definition["fee"] / NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE
levy_fee, )
levy_definition['divisibility'] return (
) + levy_definition['ticker'] format_amount(levy_fee, levy_definition["divisibility"])
+ levy_definition["ticker"]
)
async def ask_importance_transfer(ctx, common: NEMTransactionCommon, imp: NEMImportanceTransfer):
async def ask_importance_transfer(
ctx, common: NEMTransactionCommon, imp: NEMImportanceTransfer
):
if imp.mode == NEMImportanceTransferMode.ImportanceTransfer_Activate: if imp.mode == NEMImportanceTransferMode.ImportanceTransfer_Activate:
m = 'Activate' m = "Activate"
else: else:
m = 'Deactivate' m = "Deactivate"
await require_confirm_text(ctx, m + ' remote harvesting?') await require_confirm_text(ctx, m + " remote harvesting?")
await require_confirm_final(ctx, common.fee) await require_confirm_final(ctx, common.fee)
async def _require_confirm_transfer(ctx, recipient, value): async def _require_confirm_transfer(ctx, recipient, value):
text = Text('Confirm transfer', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm transfer", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold('Send %s XEM' % format_amount(value, NEM_MAX_DIVISIBILITY)) text.bold("Send %s XEM" % format_amount(value, NEM_MAX_DIVISIBILITY))
text.normal('to') text.normal("to")
text.mono(*split_address(recipient)) text.mono(*split_address(recipient))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def _require_confirm_payload(ctx, payload: bytes, encrypt=False): async def _require_confirm_payload(ctx, payload: bytes, encrypt=False):
payload = str(payload, 'utf-8') payload = str(payload, "utf-8")
if len(payload) > 48: if len(payload) > 48:
payload = payload[:48] + '..' payload = payload[:48] + ".."
if encrypt: if encrypt:
text = Text('Confirm payload', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm payload", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold('Encrypted:') text.bold("Encrypted:")
text.normal(*split_words(payload, 22)) text.normal(*split_words(payload, 22))
else: else:
text = Text('Confirm payload', ui.ICON_SEND, icon_color=ui.RED) text = Text("Confirm payload", ui.ICON_SEND, icon_color=ui.RED)
text.bold('Unencrypted:') text.bold("Unencrypted:")
text.normal(*split_words(payload, 22)) text.normal(*split_words(payload, 22))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)

@ -13,14 +13,19 @@ from ..helpers import (
from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64 from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64
def serialize_transfer(common: NEMTransactionCommon, def serialize_transfer(
transfer: NEMTransfer, common: NEMTransactionCommon,
public_key: bytes, transfer: NEMTransfer,
payload: bytes = None, public_key: bytes,
encrypted: bool = False) -> bytearray: payload: bytes = None,
tx = write_common(common, bytearray(public_key), encrypted: bool = False,
NEM_TRANSACTION_TYPE_TRANSFER, ) -> bytearray:
_get_version(common.network, transfer.mosaics)) tx = write_common(
common,
bytearray(public_key),
NEM_TRANSACTION_TYPE_TRANSFER,
_get_version(common.network, transfer.mosaics),
)
write_bytes_with_length(tx, bytearray(transfer.recipient)) write_bytes_with_length(tx, bytearray(transfer.recipient))
write_uint64(tx, transfer.amount) write_uint64(tx, transfer.amount)
@ -52,11 +57,12 @@ def serialize_mosaic(w: bytearray, namespace: str, mosaic: str, quantity: int):
write_uint64(w, quantity) write_uint64(w, quantity)
def serialize_importance_transfer(common: NEMTransactionCommon, def serialize_importance_transfer(
imp: NEMImportanceTransfer, common: NEMTransactionCommon, imp: NEMImportanceTransfer, public_key: bytes
public_key: bytes) -> bytearray: ) -> bytearray:
w = write_common(common, bytearray(public_key), w = write_common(
NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER) common, bytearray(public_key), NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER
)
write_uint32(w, imp.mode) write_uint32(w, imp.mode)
write_bytes_with_length(w, bytearray(imp.public_key)) write_bytes_with_length(w, bytearray(imp.public_key))
@ -68,7 +74,7 @@ def get_transfer_payload(transfer: NEMTransfer, node) -> [bytes, bool]:
encrypted = False encrypted = False
if transfer.public_key is not None: if transfer.public_key is not None:
if payload is None: if payload is None:
raise ValueError('Public key provided but no payload to encrypt') raise ValueError("Public key provided but no payload to encrypt")
payload = _encrypt(node, transfer.public_key, transfer.payload) payload = _encrypt(node, transfer.public_key, transfer.payload)
encrypted = True encrypted = True

@ -26,7 +26,7 @@ from .helpers import (
def validate(msg: NEMSignTx): def validate(msg: NEMSignTx):
if msg.transaction is None: if msg.transaction is None:
raise ProcessError('No common provided') raise ProcessError("No common provided")
_validate_single_tx(msg) _validate_single_tx(msg)
_validate_common(msg.transaction) _validate_common(msg.transaction)
@ -35,7 +35,7 @@ def validate(msg: NEMSignTx):
_validate_common(msg.multisig, True) _validate_common(msg.multisig, True)
_validate_multisig(msg.multisig, msg.transaction.network) _validate_multisig(msg.multisig, msg.transaction.network)
if not msg.multisig and msg.cosigning: if not msg.multisig and msg.cosigning:
raise ProcessError('No multisig transaction to cosign') raise ProcessError("No multisig transaction to cosign")
if msg.transfer: if msg.transfer:
_validate_transfer(msg.transfer, msg.transaction.network) _validate_transfer(msg.transfer, msg.transaction.network)
@ -46,7 +46,9 @@ def validate(msg: NEMSignTx):
if msg.supply_change: if msg.supply_change:
_validate_supply_change(msg.supply_change) _validate_supply_change(msg.supply_change)
if msg.aggregate_modification: if msg.aggregate_modification:
_validate_aggregate_modification(msg.aggregate_modification, msg.multisig is None) _validate_aggregate_modification(
msg.aggregate_modification, msg.multisig is None
)
if msg.importance_transfer: if msg.importance_transfer:
_validate_importance_transfer(msg.importance_transfer) _validate_importance_transfer(msg.importance_transfer)
@ -55,23 +57,24 @@ def validate_network(network: int) -> int:
if network is None: if network is None:
return NEM_NETWORK_MAINNET return NEM_NETWORK_MAINNET
if network not in (NEM_NETWORK_MAINNET, NEM_NETWORK_TESTNET, NEM_NETWORK_MIJIN): if network not in (NEM_NETWORK_MAINNET, NEM_NETWORK_TESTNET, NEM_NETWORK_MIJIN):
raise ProcessError('Invalid NEM network') raise ProcessError("Invalid NEM network")
return network return network
def _validate_single_tx(msg: NEMSignTx): def _validate_single_tx(msg: NEMSignTx):
# ensure exactly one transaction is provided # ensure exactly one transaction is provided
tx_count = \ tx_count = (
bool(msg.transfer) + \ bool(msg.transfer)
bool(msg.provision_namespace) + \ + bool(msg.provision_namespace)
bool(msg.mosaic_creation) + \ + bool(msg.mosaic_creation)
bool(msg.supply_change) + \ + bool(msg.supply_change)
bool(msg.aggregate_modification) + \ + bool(msg.aggregate_modification)
bool(msg.importance_transfer) + bool(msg.importance_transfer)
)
if tx_count == 0: if tx_count == 0:
raise ProcessError('No transaction provided') raise ProcessError("No transaction provided")
if tx_count > 1: if tx_count > 1:
raise ProcessError('More than one transaction provided') raise ProcessError("More than one transaction provided")
def _validate_common(common: NEMTransactionCommon, inner: bool = False): def _validate_common(common: NEMTransactionCommon, inner: bool = False):
@ -79,174 +82,196 @@ def _validate_common(common: NEMTransactionCommon, inner: bool = False):
err = None err = None
if common.timestamp is None: if common.timestamp is None:
err = 'timestamp' err = "timestamp"
if common.fee is None: if common.fee is None:
err = 'fee' err = "fee"
if common.deadline is None: if common.deadline is None:
err = 'deadline' err = "deadline"
if not inner and common.signer: if not inner and common.signer:
raise ProcessError('Signer not allowed in outer transaction') raise ProcessError("Signer not allowed in outer transaction")
if inner and common.signer is None: if inner and common.signer is None:
err = 'signer' err = "signer"
if err: if err:
if inner: if inner:
raise ProcessError('No %s provided in inner transaction' % err) raise ProcessError("No %s provided in inner transaction" % err)
else: else:
raise ProcessError('No %s provided' % err) raise ProcessError("No %s provided" % err)
if common.signer is not None: if common.signer is not None:
_validate_public_key(common.signer, 'Invalid signer public key in inner transaction') _validate_public_key(
common.signer, "Invalid signer public key in inner transaction"
)
def _validate_public_key(public_key: bytes, err_msg: str): def _validate_public_key(public_key: bytes, err_msg: str):
if not public_key: if not public_key:
raise ProcessError('%s (none provided)' % err_msg) raise ProcessError("%s (none provided)" % err_msg)
if len(public_key) != NEM_PUBLIC_KEY_SIZE: if len(public_key) != NEM_PUBLIC_KEY_SIZE:
raise ProcessError('%s (invalid length)' % err_msg) raise ProcessError("%s (invalid length)" % err_msg)
def _validate_importance_transfer(importance_transfer: NEMImportanceTransfer): def _validate_importance_transfer(importance_transfer: NEMImportanceTransfer):
if importance_transfer.mode is None: if importance_transfer.mode is None:
raise ProcessError('No mode provided') raise ProcessError("No mode provided")
_validate_public_key(importance_transfer.public_key, 'Invalid remote account public key provided') _validate_public_key(
importance_transfer.public_key, "Invalid remote account public key provided"
)
def _validate_multisig(multisig: NEMTransactionCommon, network: int): def _validate_multisig(multisig: NEMTransactionCommon, network: int):
if multisig.network != network: if multisig.network != network:
raise ProcessError('Inner transaction network is different') raise ProcessError("Inner transaction network is different")
_validate_public_key(multisig.signer, 'Invalid multisig signer public key provided') _validate_public_key(multisig.signer, "Invalid multisig signer public key provided")
def _validate_aggregate_modification( def _validate_aggregate_modification(
aggregate_modification: NEMAggregateModification, aggregate_modification: NEMAggregateModification, creation: bool = False
creation: bool = False): ):
if creation and not aggregate_modification.modifications: if creation and not aggregate_modification.modifications:
raise ProcessError('No modifications provided') raise ProcessError("No modifications provided")
for m in aggregate_modification.modifications: for m in aggregate_modification.modifications:
if not m.type: if not m.type:
raise ProcessError('No modification type provided') raise ProcessError("No modification type provided")
if m.type not in ( if m.type not in (
NEMModificationType.CosignatoryModification_Add, NEMModificationType.CosignatoryModification_Add,
NEMModificationType.CosignatoryModification_Delete NEMModificationType.CosignatoryModification_Delete,
): ):
raise ProcessError('Unknown aggregate modification') raise ProcessError("Unknown aggregate modification")
if creation and m.type == NEMModificationType.CosignatoryModification_Delete: if creation and m.type == NEMModificationType.CosignatoryModification_Delete:
raise ProcessError('Cannot remove cosignatory when converting account') raise ProcessError("Cannot remove cosignatory when converting account")
_validate_public_key(m.public_key, 'Invalid cosignatory public key provided') _validate_public_key(m.public_key, "Invalid cosignatory public key provided")
def _validate_supply_change(supply_change: NEMMosaicSupplyChange): def _validate_supply_change(supply_change: NEMMosaicSupplyChange):
if supply_change.namespace is None: if supply_change.namespace is None:
raise ProcessError('No namespace provided') raise ProcessError("No namespace provided")
if supply_change.mosaic is None: if supply_change.mosaic is None:
raise ProcessError('No mosaic provided') raise ProcessError("No mosaic provided")
if supply_change.type is None: if supply_change.type is None:
raise ProcessError('No type provided') raise ProcessError("No type provided")
elif supply_change.type not in [NEMSupplyChangeType.SupplyChange_Decrease, NEMSupplyChangeType.SupplyChange_Increase]: elif supply_change.type not in [
raise ProcessError('Invalid supply change type') NEMSupplyChangeType.SupplyChange_Decrease,
NEMSupplyChangeType.SupplyChange_Increase,
]:
raise ProcessError("Invalid supply change type")
if supply_change.delta is None: if supply_change.delta is None:
raise ProcessError('No delta provided') raise ProcessError("No delta provided")
def _validate_mosaic_creation(mosaic_creation: NEMMosaicCreation, network: int): def _validate_mosaic_creation(mosaic_creation: NEMMosaicCreation, network: int):
if mosaic_creation.definition is None: if mosaic_creation.definition is None:
raise ProcessError('No mosaic definition provided') raise ProcessError("No mosaic definition provided")
if mosaic_creation.sink is None: if mosaic_creation.sink is None:
raise ProcessError('No creation sink provided') raise ProcessError("No creation sink provided")
if mosaic_creation.fee is None: if mosaic_creation.fee is None:
raise ProcessError('No creation sink fee provided') raise ProcessError("No creation sink fee provided")
if not nem.validate_address(mosaic_creation.sink, network): if not nem.validate_address(mosaic_creation.sink, network):
raise ProcessError('Invalid creation sink address') raise ProcessError("Invalid creation sink address")
if mosaic_creation.definition.name is not None: if mosaic_creation.definition.name is not None:
raise ProcessError('Name not allowed in mosaic creation transactions') raise ProcessError("Name not allowed in mosaic creation transactions")
if mosaic_creation.definition.ticker is not None: if mosaic_creation.definition.ticker is not None:
raise ProcessError('Ticker not allowed in mosaic creation transactions') raise ProcessError("Ticker not allowed in mosaic creation transactions")
if mosaic_creation.definition.networks: if mosaic_creation.definition.networks:
raise ProcessError('Networks not allowed in mosaic creation transactions') raise ProcessError("Networks not allowed in mosaic creation transactions")
if mosaic_creation.definition.namespace is None: if mosaic_creation.definition.namespace is None:
raise ProcessError('No mosaic namespace provided') raise ProcessError("No mosaic namespace provided")
if mosaic_creation.definition.mosaic is None: if mosaic_creation.definition.mosaic is None:
raise ProcessError('No mosaic name provided') raise ProcessError("No mosaic name provided")
if mosaic_creation.definition.supply is not None and mosaic_creation.definition.divisibility is None: if (
raise ProcessError('Definition divisibility needs to be provided when supply is') mosaic_creation.definition.supply is not None
if mosaic_creation.definition.supply is None and mosaic_creation.definition.divisibility is not None: and mosaic_creation.definition.divisibility is None
raise ProcessError('Definition supply needs to be provided when divisibility is') ):
raise ProcessError(
"Definition divisibility needs to be provided when supply is"
)
if (
mosaic_creation.definition.supply is None
and mosaic_creation.definition.divisibility is not None
):
raise ProcessError(
"Definition supply needs to be provided when divisibility is"
)
if mosaic_creation.definition.levy is not None: if mosaic_creation.definition.levy is not None:
if mosaic_creation.definition.fee is None: if mosaic_creation.definition.fee is None:
raise ProcessError('No levy fee provided') raise ProcessError("No levy fee provided")
if mosaic_creation.definition.levy_address is None: if mosaic_creation.definition.levy_address is None:
raise ProcessError('No levy address provided') raise ProcessError("No levy address provided")
if mosaic_creation.definition.levy_namespace is None: if mosaic_creation.definition.levy_namespace is None:
raise ProcessError('No levy namespace provided') raise ProcessError("No levy namespace provided")
if mosaic_creation.definition.levy_mosaic is None: if mosaic_creation.definition.levy_mosaic is None:
raise ProcessError('No levy mosaic name provided') raise ProcessError("No levy mosaic name provided")
if mosaic_creation.definition.divisibility is None: if mosaic_creation.definition.divisibility is None:
raise ProcessError('No divisibility provided') raise ProcessError("No divisibility provided")
if mosaic_creation.definition.supply is None: if mosaic_creation.definition.supply is None:
raise ProcessError('No supply provided') raise ProcessError("No supply provided")
if mosaic_creation.definition.mutable_supply is None: if mosaic_creation.definition.mutable_supply is None:
raise ProcessError('No supply mutability provided') raise ProcessError("No supply mutability provided")
if mosaic_creation.definition.transferable is None: if mosaic_creation.definition.transferable is None:
raise ProcessError('No mosaic transferability provided') raise ProcessError("No mosaic transferability provided")
if mosaic_creation.definition.description is None: if mosaic_creation.definition.description is None:
raise ProcessError('No description provided') raise ProcessError("No description provided")
if mosaic_creation.definition.divisibility > NEM_MAX_DIVISIBILITY: if mosaic_creation.definition.divisibility > NEM_MAX_DIVISIBILITY:
raise ProcessError('Invalid divisibility provided') raise ProcessError("Invalid divisibility provided")
if mosaic_creation.definition.supply > NEM_MAX_SUPPLY: if mosaic_creation.definition.supply > NEM_MAX_SUPPLY:
raise ProcessError('Invalid supply provided') raise ProcessError("Invalid supply provided")
if not nem.validate_address(mosaic_creation.definition.levy_address, network): if not nem.validate_address(mosaic_creation.definition.levy_address, network):
raise ProcessError('Invalid levy address') raise ProcessError("Invalid levy address")
def _validate_provision_namespace(provision_namespace: NEMProvisionNamespace, network: int): def _validate_provision_namespace(
provision_namespace: NEMProvisionNamespace, network: int
):
if provision_namespace.namespace is None: if provision_namespace.namespace is None:
raise ProcessError('No namespace provided') raise ProcessError("No namespace provided")
if provision_namespace.sink is None: if provision_namespace.sink is None:
raise ProcessError('No rental sink provided') raise ProcessError("No rental sink provided")
if provision_namespace.fee is None: if provision_namespace.fee is None:
raise ProcessError('No rental sink fee provided') raise ProcessError("No rental sink fee provided")
if not nem.validate_address(provision_namespace.sink, network): if not nem.validate_address(provision_namespace.sink, network):
raise ProcessError('Invalid rental sink address') raise ProcessError("Invalid rental sink address")
def _validate_transfer(transfer: NEMTransfer, network: int): def _validate_transfer(transfer: NEMTransfer, network: int):
if transfer.recipient is None: if transfer.recipient is None:
raise ProcessError('No recipient provided') raise ProcessError("No recipient provided")
if transfer.amount is None: if transfer.amount is None:
raise ProcessError('No amount provided') raise ProcessError("No amount provided")
if transfer.public_key is not None: if transfer.public_key is not None:
_validate_public_key(transfer.public_key, 'Invalid recipient public key') _validate_public_key(transfer.public_key, "Invalid recipient public key")
if transfer.payload is None: if transfer.payload is None:
raise ProcessError('Public key provided but no payload to encrypt') raise ProcessError("Public key provided but no payload to encrypt")
if transfer.payload: if transfer.payload:
if len(transfer.payload) > NEM_MAX_PLAIN_PAYLOAD_SIZE: if len(transfer.payload) > NEM_MAX_PLAIN_PAYLOAD_SIZE:
raise ProcessError('Payload too large') raise ProcessError("Payload too large")
if transfer.public_key and len(transfer.payload) > NEM_MAX_ENCRYPTED_PAYLOAD_SIZE: if (
raise ProcessError('Payload too large') transfer.public_key
and len(transfer.payload) > NEM_MAX_ENCRYPTED_PAYLOAD_SIZE
):
raise ProcessError("Payload too large")
if not nem.validate_address(transfer.recipient, network): if not nem.validate_address(transfer.recipient, network):
raise ProcessError('Invalid recipient address') raise ProcessError("Invalid recipient address")
for m in transfer.mosaics: for m in transfer.mosaics:
if m.namespace is None: if m.namespace is None:
raise ProcessError('No mosaic namespace provided') raise ProcessError("No mosaic namespace provided")
if m.mosaic is None: if m.mosaic is None:
raise ProcessError('No mosaic name provided') raise ProcessError("No mosaic name provided")
if m.quantity is None: if m.quantity is None:
raise ProcessError('No mosaic quantity provided') raise ProcessError("No mosaic quantity provided")

@ -28,10 +28,12 @@ def write_bytes_with_length(w, buf: bytearray):
write_bytes(w, buf) write_bytes(w, buf)
def write_common(common: NEMTransactionCommon, def write_common(
public_key: bytearray, common: NEMTransactionCommon,
transaction_type: int, public_key: bytearray,
version: int = None) -> bytearray: transaction_type: int,
version: int = None,
) -> bytearray:
ret = bytearray() ret = bytearray()
write_uint32(ret, transaction_type) write_uint32(ret, transaction_type)

@ -14,46 +14,55 @@ from trezor.wire import protobuf_workflow, register
def dispatch_GetPublicKey(*args, **kwargs): def dispatch_GetPublicKey(*args, **kwargs):
from .get_public_key import get_public_key from .get_public_key import get_public_key
return get_public_key(*args, **kwargs) return get_public_key(*args, **kwargs)
def dispatch_GetAddress(*args, **kwargs): def dispatch_GetAddress(*args, **kwargs):
from .get_address import get_address from .get_address import get_address
return get_address(*args, **kwargs) return get_address(*args, **kwargs)
def dispatch_GetEntropy(*args, **kwargs): def dispatch_GetEntropy(*args, **kwargs):
from .get_entropy import get_entropy from .get_entropy import get_entropy
return get_entropy(*args, **kwargs) return get_entropy(*args, **kwargs)
def dispatch_SignTx(*args, **kwargs): def dispatch_SignTx(*args, **kwargs):
from .sign_tx import sign_tx from .sign_tx import sign_tx
return sign_tx(*args, **kwargs) return sign_tx(*args, **kwargs)
def dispatch_SignMessage(*args, **kwargs): def dispatch_SignMessage(*args, **kwargs):
from .sign_message import sign_message from .sign_message import sign_message
return sign_message(*args, **kwargs) return sign_message(*args, **kwargs)
def dispatch_VerifyMessage(*args, **kwargs): def dispatch_VerifyMessage(*args, **kwargs):
from .verify_message import verify_message from .verify_message import verify_message
return verify_message(*args, **kwargs) return verify_message(*args, **kwargs)
def dispatch_SignIdentity(*args, **kwargs): def dispatch_SignIdentity(*args, **kwargs):
from .sign_identity import sign_identity from .sign_identity import sign_identity
return sign_identity(*args, **kwargs) return sign_identity(*args, **kwargs)
def dispatch_GetECDHSessionKey(*args, **kwargs): def dispatch_GetECDHSessionKey(*args, **kwargs):
from .ecdh import get_ecdh_session_key from .ecdh import get_ecdh_session_key
return get_ecdh_session_key(*args, **kwargs) return get_ecdh_session_key(*args, **kwargs)
def dispatch_CipherKeyValue(*args, **kwargs): def dispatch_CipherKeyValue(*args, **kwargs):
from .cipher_key_value import cipher_key_value from .cipher_key_value import cipher_key_value
return cipher_key_value(*args, **kwargs) return cipher_key_value(*args, **kwargs)

@ -11,15 +11,15 @@ from apps.common.confirm import require_confirm
async def cipher_key_value(ctx, msg): async def cipher_key_value(ctx, msg):
if len(msg.value) % 16 > 0: if len(msg.value) % 16 > 0:
raise wire.DataError('Value length must be a multiple of 16') raise wire.DataError("Value length must be a multiple of 16")
encrypt = msg.encrypt encrypt = msg.encrypt
decrypt = not msg.encrypt decrypt = not msg.encrypt
if (encrypt and msg.ask_on_encrypt) or (decrypt and msg.ask_on_decrypt): if (encrypt and msg.ask_on_encrypt) or (decrypt and msg.ask_on_decrypt):
if encrypt: if encrypt:
title = 'Encrypt value' title = "Encrypt value"
else: else:
title = 'Decrypt value' title = "Decrypt value"
text = Text(title) text = Text(title)
text.normal(msg.key) text.normal(msg.key)
await require_confirm(ctx, text) await require_confirm(ctx, text)
@ -31,8 +31,8 @@ async def cipher_key_value(ctx, msg):
def compute_cipher_key_value(msg, seckey: bytes) -> bytes: def compute_cipher_key_value(msg, seckey: bytes) -> bytes:
data = msg.key data = msg.key
data += 'E1' if msg.ask_on_encrypt else 'E0' data += "E1" if msg.ask_on_encrypt else "E0"
data += 'D1' if msg.ask_on_decrypt else 'D0' data += "D1" if msg.ask_on_decrypt else "D0"
data = hmac.new(seckey, data, sha512).digest() data = hmac.new(seckey, data, sha512).digest()
key = data[:32] key = data[:32]
if msg.iv and len(msg.iv) == 16: if msg.iv and len(msg.iv) == 16:

@ -15,7 +15,7 @@ from apps.wallet.sign_identity import (
async def get_ecdh_session_key(ctx, msg): async def get_ecdh_session_key(ctx, msg):
if msg.ecdsa_curve_name is None: if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = 'secp256k1' msg.ecdsa_curve_name = "secp256k1"
identity = serialize_identity(msg.identity) identity = serialize_identity(msg.identity)
@ -24,42 +24,47 @@ async def get_ecdh_session_key(ctx, msg):
address_n = get_ecdh_path(identity, msg.identity.index or 0) address_n = get_ecdh_path(identity, msg.identity.index or 0)
node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name) node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name)
session_key = ecdh(seckey=node.private_key(), session_key = ecdh(
peer_public_key=msg.peer_public_key, seckey=node.private_key(),
curve=msg.ecdsa_curve_name) peer_public_key=msg.peer_public_key,
curve=msg.ecdsa_curve_name,
)
return ECDHSessionKey(session_key=session_key) return ECDHSessionKey(session_key=session_key)
async def require_confirm_ecdh_session_key(ctx, identity): async def require_confirm_ecdh_session_key(ctx, identity):
lines = chunks(serialize_identity_without_proto(identity), 18) lines = chunks(serialize_identity_without_proto(identity), 18)
proto = identity.proto.upper() if identity.proto else 'identity' proto = identity.proto.upper() if identity.proto else "identity"
text = Text('Decrypt %s' % proto) text = Text("Decrypt %s" % proto)
text.mono(*lines) text.mono(*lines)
await require_confirm(ctx, text) await require_confirm(ctx, text)
def get_ecdh_path(identity: str, index: int): def get_ecdh_path(identity: str, index: int):
identity_hash = sha256(pack('<I', index) + identity).digest() identity_hash = sha256(pack("<I", index) + identity).digest()
address_n = (17, ) + unpack('<IIII', identity_hash[:16]) address_n = (17,) + unpack("<IIII", identity_hash[:16])
address_n = [HARDENED | x for x in address_n] address_n = [HARDENED | x for x in address_n]
return address_n return address_n
def ecdh(seckey: bytes, peer_public_key: bytes, curve: str) -> bytes: def ecdh(seckey: bytes, peer_public_key: bytes, curve: str) -> bytes:
if curve == 'secp256k1': if curve == "secp256k1":
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
session_key = secp256k1.multiply(seckey, peer_public_key) session_key = secp256k1.multiply(seckey, peer_public_key)
elif curve == 'nist256p1': elif curve == "nist256p1":
from trezor.crypto.curve import nist256p1 from trezor.crypto.curve import nist256p1
session_key = nist256p1.multiply(seckey, peer_public_key) session_key = nist256p1.multiply(seckey, peer_public_key)
elif curve == 'curve25519': elif curve == "curve25519":
from trezor.crypto.curve import curve25519 from trezor.crypto.curve import curve25519
if peer_public_key[0] != 0x40: if peer_public_key[0] != 0x40:
raise ValueError('Curve25519 public key should start with 0x40') raise ValueError("Curve25519 public key should start with 0x40")
session_key = b'\x04' + curve25519.multiply(seckey, peer_public_key[1:]) session_key = b"\x04" + curve25519.multiply(seckey, peer_public_key[1:])
else: else:
raise ValueError('Unsupported curve for ECDH: ' + curve) raise ValueError("Unsupported curve for ECDH: " + curve)
return session_key return session_key

@ -7,7 +7,7 @@ from apps.wallet.sign_tx import addresses
async def get_address(ctx, msg): async def get_address(ctx, msg):
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
node = await seed.derive_node(ctx, msg.address_n, curve_name=coin.curve_name) node = await seed.derive_node(ctx, msg.address_n, curve_name=coin.curve_name)
@ -18,7 +18,12 @@ async def get_address(ctx, msg):
while True: while True:
if await show_address(ctx, address_short): if await show_address(ctx, address_short):
break break
if await show_qr(ctx, address.upper() if msg.script_type == InputScriptType.SPENDWITNESS else address): if await show_qr(
ctx,
address.upper()
if msg.script_type == InputScriptType.SPENDWITNESS
else address,
):
break break
return Address(address=address) return Address(address=address)

@ -8,9 +8,9 @@ from apps.common.confirm import require_confirm
async def get_entropy(ctx, msg): async def get_entropy(ctx, msg):
text = Text('Confirm entropy') text = Text("Confirm entropy")
text.bold('Do you really want', 'to send entropy?') text.bold("Do you really want", "to send entropy?")
text.normal('Continue only if you', 'know what you are doing!') text.normal("Continue only if you", "know what you are doing!")
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
size = min(msg.size, 1024) size = min(msg.size, 1024)

@ -12,7 +12,7 @@ from apps.common.confirm import require_confirm
async def get_public_key(ctx, msg): async def get_public_key(ctx, msg):
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
curve_name = msg.ecdsa_curve_name curve_name = msg.ecdsa_curve_name
@ -23,13 +23,14 @@ async def get_public_key(ctx, msg):
node_xpub = node.serialize_public(coin.xpub_magic) node_xpub = node.serialize_public(coin.xpub_magic)
pubkey = node.public_key() pubkey = node.public_key()
if pubkey[0] == 1: if pubkey[0] == 1:
pubkey = b'\x00' + pubkey[1:] pubkey = b"\x00" + pubkey[1:]
node_type = HDNodeType( node_type = HDNodeType(
depth=node.depth(), depth=node.depth(),
child_num=node.child_num(), child_num=node.child_num(),
fingerprint=node.fingerprint(), fingerprint=node.fingerprint(),
chain_code=node.chain_code(), chain_code=node.chain_code(),
public_key=pubkey) public_key=pubkey,
)
if msg.show_display: if msg.show_display:
await _show_pubkey(ctx, pubkey) await _show_pubkey(ctx, pubkey)
@ -39,9 +40,6 @@ async def get_public_key(ctx, msg):
async def _show_pubkey(ctx, pubkey: bytes): async def _show_pubkey(ctx, pubkey: bytes):
lines = chunks(hexlify(pubkey).decode(), 18) lines = chunks(hexlify(pubkey).decode(), 18)
text = Text('Confirm public key', ui.ICON_RECEIVE, icon_color=ui.GREEN) text = Text("Confirm public key", ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.mono(*lines) text.mono(*lines)
return await require_confirm( return await require_confirm(ctx, text, code=ButtonRequestType.PublicKey)
ctx,
text,
code=ButtonRequestType.PublicKey)

@ -12,7 +12,7 @@ from apps.common.confirm import require_confirm
async def sign_identity(ctx, msg): async def sign_identity(ctx, msg):
if msg.ecdsa_curve_name is None: if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = 'secp256k1' msg.ecdsa_curve_name = "secp256k1"
identity = serialize_identity(msg.identity) identity = serialize_identity(msg.identity)
@ -21,25 +21,40 @@ async def sign_identity(ctx, msg):
address_n = get_identity_path(identity, msg.identity.index or 0) address_n = get_identity_path(identity, msg.identity.index or 0)
node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name) node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name)
coin = coins.by_name('Bitcoin') coin = coins.by_name("Bitcoin")
if msg.ecdsa_curve_name == 'secp256k1': if msg.ecdsa_curve_name == "secp256k1":
address = node.address(coin.address_type) # hardcoded bitcoin address type address = node.address(coin.address_type) # hardcoded bitcoin address type
else: else:
address = None address = None
pubkey = node.public_key() pubkey = node.public_key()
if pubkey[0] == 0x01: if pubkey[0] == 0x01:
pubkey = b'\x00' + pubkey[1:] pubkey = b"\x00" + pubkey[1:]
seckey = node.private_key() seckey = node.private_key()
if msg.identity.proto == 'gpg': if msg.identity.proto == "gpg":
signature = sign_challenge( signature = sign_challenge(
seckey, msg.challenge_hidden, msg.challenge_visual, 'gpg', msg.ecdsa_curve_name) seckey,
elif msg.identity.proto == 'ssh': msg.challenge_hidden,
msg.challenge_visual,
"gpg",
msg.ecdsa_curve_name,
)
elif msg.identity.proto == "ssh":
signature = sign_challenge( signature = sign_challenge(
seckey, msg.challenge_hidden, msg.challenge_visual, 'ssh', msg.ecdsa_curve_name) seckey,
msg.challenge_hidden,
msg.challenge_visual,
"ssh",
msg.ecdsa_curve_name,
)
else: else:
signature = sign_challenge( signature = sign_challenge(
seckey, msg.challenge_hidden, msg.challenge_visual, coin, msg.ecdsa_curve_name) seckey,
msg.challenge_hidden,
msg.challenge_visual,
coin,
msg.ecdsa_curve_name,
)
return SignedIdentity(address=address, public_key=pubkey, signature=signature) return SignedIdentity(address=address, public_key=pubkey, signature=signature)
@ -52,22 +67,22 @@ async def require_confirm_sign_identity(ctx, identity, challenge_visual):
lines.append(ui.MONO) lines.append(ui.MONO)
lines.extend(chunks(serialize_identity_without_proto(identity), 18)) lines.extend(chunks(serialize_identity_without_proto(identity), 18))
proto = identity.proto.upper() if identity.proto else 'identity' proto = identity.proto.upper() if identity.proto else "identity"
text = Text('Sign %s' % proto) text = Text("Sign %s" % proto)
text.normal(*lines) text.normal(*lines)
await require_confirm(ctx, text) await require_confirm(ctx, text)
def serialize_identity(identity): def serialize_identity(identity):
s = '' s = ""
if identity.proto: if identity.proto:
s += identity.proto + '://' s += identity.proto + "://"
if identity.user: if identity.user:
s += identity.user + '@' s += identity.user + "@"
if identity.host: if identity.host:
s += identity.host s += identity.host
if identity.port: if identity.port:
s += ':' + identity.port s += ":" + identity.port
if identity.path: if identity.path:
s += identity.path s += identity.path
return s return s
@ -82,52 +97,53 @@ def serialize_identity_without_proto(identity):
def get_identity_path(identity: str, index: int): def get_identity_path(identity: str, index: int):
identity_hash = sha256(pack('<I', index) + identity).digest() identity_hash = sha256(pack("<I", index) + identity).digest()
address_n = (13, ) + unpack('<IIII', identity_hash[:16]) address_n = (13,) + unpack("<IIII", identity_hash[:16])
address_n = [HARDENED | x for x in address_n] address_n = [HARDENED | x for x in address_n]
return address_n return address_n
def sign_challenge(seckey: bytes, def sign_challenge(
challenge_hidden: bytes, seckey: bytes, challenge_hidden: bytes, challenge_visual: str, sigtype, curve: str
challenge_visual: str, ) -> bytes:
sigtype,
curve: str) -> bytes:
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
if curve == 'secp256k1':
if curve == "secp256k1":
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
elif curve == 'nist256p1': elif curve == "nist256p1":
from trezor.crypto.curve import nist256p1 from trezor.crypto.curve import nist256p1
elif curve == 'ed25519': elif curve == "ed25519":
from trezor.crypto.curve import ed25519 from trezor.crypto.curve import ed25519
from apps.common.signverify import message_digest from apps.common.signverify import message_digest
if sigtype == 'gpg': if sigtype == "gpg":
data = challenge_hidden data = challenge_hidden
elif sigtype == 'ssh': elif sigtype == "ssh":
if curve != 'ed25519': if curve != "ed25519":
data = sha256(challenge_hidden).digest() data = sha256(challenge_hidden).digest()
else: else:
data = challenge_hidden data = challenge_hidden
else: else:
# sigtype is coin # sigtype is coin
challenge = sha256(challenge_hidden).digest() + sha256(challenge_visual).digest() challenge = (
sha256(challenge_hidden).digest() + sha256(challenge_visual).digest()
)
data = message_digest(sigtype, challenge) data = message_digest(sigtype, challenge)
if curve == 'secp256k1': if curve == "secp256k1":
signature = secp256k1.sign(seckey, data) signature = secp256k1.sign(seckey, data)
elif curve == 'nist256p1': elif curve == "nist256p1":
signature = nist256p1.sign(seckey, data) signature = nist256p1.sign(seckey, data)
elif curve == 'ed25519': elif curve == "ed25519":
signature = ed25519.sign(seckey, data) signature = ed25519.sign(seckey, data)
else: else:
raise ValueError('Unknown curve') raise ValueError("Unknown curve")
if curve == 'ed25519': if curve == "ed25519":
signature = b'\x00' + signature signature = b"\x00" + signature
elif sigtype == 'gpg' or sigtype == 'ssh': elif sigtype == "gpg" or sigtype == "ssh":
signature = b'\x00' + signature[1:] signature = b"\x00" + signature[1:]
return signature return signature

@ -13,7 +13,7 @@ from apps.wallet.sign_tx.addresses import get_address
async def sign_message(ctx, msg): async def sign_message(ctx, msg):
message = msg.message message = msg.message
address_n = msg.address_n address_n = msg.address_n
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or 0 script_type = msg.script_type or 0
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
@ -33,13 +33,13 @@ async def sign_message(ctx, msg):
elif script_type == SPENDWITNESS: elif script_type == SPENDWITNESS:
signature = bytes([signature[0] + 8]) + signature[1:] signature = bytes([signature[0] + 8]) + signature[1:]
else: else:
raise wire.ProcessError('Unsupported script type') raise wire.ProcessError("Unsupported script type")
return MessageSignature(address=address, signature=signature) return MessageSignature(address=address, signature=signature)
async def require_confirm_sign_message(ctx, message): async def require_confirm_sign_message(ctx, message):
message = split_message(message) message = split_message(message)
text = Text('Sign message') text = Text("Sign message")
text.normal(*message) text.normal(*message)
await require_confirm(ctx, text) await require_confirm(ctx, text)

@ -22,18 +22,23 @@ class AddressError(Exception):
pass pass
def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=None) -> str: def get_address(
script_type: InputScriptType, coin: CoinInfo, node, multisig=None
if script_type == InputScriptType.SPENDADDRESS or script_type == InputScriptType.SPENDMULTISIG: ) -> str:
if (
script_type == InputScriptType.SPENDADDRESS
or script_type == InputScriptType.SPENDMULTISIG
):
if multisig: # p2sh multisig if multisig: # p2sh multisig
pubkey = node.public_key() pubkey = node.public_key()
index = multisig_pubkey_index(multisig, pubkey) index = multisig_pubkey_index(multisig, pubkey)
if index is None: if index is None:
raise AddressError(FailureType.ProcessError, raise AddressError(FailureType.ProcessError, "Public key not found")
'Public key not found')
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError, raise AddressError(
'Multisig not enabled on this coin') FailureType.ProcessError, "Multisig not enabled on this coin"
)
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
address = address_multisig_p2sh(pubkeys, multisig.m, coin) address = address_multisig_p2sh(pubkeys, multisig.m, coin)
@ -41,8 +46,7 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
address = address_to_cashaddr(address, coin) address = address_to_cashaddr(address, coin)
return address return address
if script_type == InputScriptType.SPENDMULTISIG: if script_type == InputScriptType.SPENDMULTISIG:
raise AddressError(FailureType.ProcessError, raise AddressError(FailureType.ProcessError, "Multisig details required")
'Multisig details required')
# p2pkh # p2pkh
address = node.address(coin.address_type) address = node.address(coin.address_type)
@ -52,8 +56,9 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or native p2wsh elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or native p2wsh
if not coin.segwit or not coin.bech32_prefix: if not coin.segwit or not coin.bech32_prefix:
raise AddressError(FailureType.ProcessError, raise AddressError(
'Segwit not enabled on this coin') FailureType.ProcessError, "Segwit not enabled on this coin"
)
# native p2wsh multisig # native p2wsh multisig
if multisig is not None: if multisig is not None:
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
@ -62,10 +67,13 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
# native p2wpkh # native p2wpkh
return address_p2wpkh(node.public_key(), coin.bech32_prefix) return address_p2wpkh(node.public_key(), coin.bech32_prefix)
elif script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh or p2wsh nested in p2sh elif (
script_type == InputScriptType.SPENDP2SHWITNESS
): # p2wpkh or p2wsh nested in p2sh
if not coin.segwit or coin.address_type_p2sh is None: if not coin.segwit or coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError, raise AddressError(
'Segwit not enabled on this coin') FailureType.ProcessError, "Segwit not enabled on this coin"
)
# p2wsh multisig nested in p2sh # p2wsh multisig nested in p2sh
if multisig is not None: if multisig is not None:
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
@ -75,14 +83,14 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
return address_p2wpkh_in_p2sh(node.public_key(), coin) return address_p2wpkh_in_p2sh(node.public_key(), coin)
else: else:
raise AddressError(FailureType.ProcessError, raise AddressError(FailureType.ProcessError, "Invalid script type")
'Invalid script type')
def address_multisig_p2sh(pubkeys: bytes, m: int, coin: CoinInfo): def address_multisig_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError, raise AddressError(
'Multisig not enabled on this coin') FailureType.ProcessError, "Multisig not enabled on this coin"
)
redeem_script = output_script_multisig(pubkeys, m) redeem_script = output_script_multisig(pubkeys, m)
redeem_script_hash = sha256_ripemd160_digest(redeem_script) redeem_script_hash = sha256_ripemd160_digest(redeem_script)
return address_p2sh(redeem_script_hash, coin) return address_p2sh(redeem_script_hash, coin)
@ -90,8 +98,9 @@ def address_multisig_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
def address_multisig_p2wsh_in_p2sh(pubkeys: bytes, m: int, coin: CoinInfo): def address_multisig_p2wsh_in_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError, raise AddressError(
'Multisig not enabled on this coin') FailureType.ProcessError, "Multisig not enabled on this coin"
)
witness_script = output_script_multisig(pubkeys, m) witness_script = output_script_multisig(pubkeys, m)
witness_script_hash = sha256(witness_script).digest() witness_script_hash = sha256(witness_script).digest()
return address_p2wsh_in_p2sh(witness_script_hash, coin) return address_p2wsh_in_p2sh(witness_script_hash, coin)
@ -99,8 +108,9 @@ def address_multisig_p2wsh_in_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
def address_multisig_p2wsh(pubkeys: bytes, m: int, hrp: str): def address_multisig_p2wsh(pubkeys: bytes, m: int, hrp: str):
if not hrp: if not hrp:
raise AddressError(FailureType.ProcessError, raise AddressError(
'Multisig not enabled on this coin') FailureType.ProcessError, "Multisig not enabled on this coin"
)
witness_script = output_script_multisig(pubkeys, m) witness_script = output_script_multisig(pubkeys, m)
witness_script_hash = sha256(witness_script).digest() witness_script_hash = sha256(witness_script).digest()
return address_p2wsh(witness_script_hash, hrp) return address_p2wsh(witness_script_hash, hrp)
@ -133,24 +143,21 @@ def address_p2wpkh(pubkey: bytes, hrp: str) -> str:
pubkeyhash = ecdsa_hash_pubkey(pubkey) pubkeyhash = ecdsa_hash_pubkey(pubkey)
address = bech32.encode(hrp, _BECH32_WITVER, pubkeyhash) address = bech32.encode(hrp, _BECH32_WITVER, pubkeyhash)
if address is None: if address is None:
raise AddressError(FailureType.ProcessError, raise AddressError(FailureType.ProcessError, "Invalid address")
'Invalid address')
return address return address
def address_p2wsh(witness_script_hash: bytes, hrp: str) -> str: def address_p2wsh(witness_script_hash: bytes, hrp: str) -> str:
address = bech32.encode(hrp, _BECH32_WITVER, witness_script_hash) address = bech32.encode(hrp, _BECH32_WITVER, witness_script_hash)
if address is None: if address is None:
raise AddressError(FailureType.ProcessError, raise AddressError(FailureType.ProcessError, "Invalid address")
'Invalid address')
return address return address
def decode_bech32_address(prefix: str, address: str) -> bytes: def decode_bech32_address(prefix: str, address: str) -> bytes:
witver, raw = bech32.decode(prefix, address) witver, raw = bech32.decode(prefix, address)
if witver != _BECH32_WITVER: if witver != _BECH32_WITVER:
raise AddressError(FailureType.ProcessError, raise AddressError(FailureType.ProcessError, "Invalid address witness program")
'Invalid address witness program')
return bytes(raw) return bytes(raw)
@ -162,7 +169,7 @@ def address_to_cashaddr(address: str, coin: CoinInfo) -> str:
elif version == coin.address_type_p2sh: elif version == coin.address_type_p2sh:
version = cashaddr.ADDRESS_TYPE_P2SH version = cashaddr.ADDRESS_TYPE_P2SH
else: else:
raise ValueError('Unknown cashaddr address type') raise ValueError("Unknown cashaddr address type")
return cashaddr.encode(coin.cashaddr_prefix, version, data) return cashaddr.encode(coin.cashaddr_prefix, version, data)
@ -170,7 +177,7 @@ def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
if pubkey[0] == 0x04: if pubkey[0] == 0x04:
ensure(len(pubkey) == 65) # uncompressed format ensure(len(pubkey) == 65) # uncompressed format
elif pubkey[0] == 0x00: elif pubkey[0] == 0x00:
ensure(len(pubkey) == 1) # point at infinity ensure(len(pubkey) == 1) # point at infinity
else: else:
ensure(len(pubkey) == 33) # compresssed format ensure(len(pubkey) == 33) # compresssed format
h = sha256(pubkey).digest() h = sha256(pubkey).digest()
@ -179,7 +186,9 @@ def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
def address_short(coin: CoinInfo, address: str) -> str: def address_short(coin: CoinInfo, address: str) -> str:
if coin.cashaddr_prefix is not None and address.startswith(coin.cashaddr_prefix + ':'): if coin.cashaddr_prefix is not None and address.startswith(
return address[len(coin.cashaddr_prefix) + 1:] coin.cashaddr_prefix + ":"
):
return address[len(coin.cashaddr_prefix) + 1 :]
else: else:
return address return address

@ -20,14 +20,12 @@ from apps.common.coininfo import CoinInfo
class UiConfirmOutput: class UiConfirmOutput:
def __init__(self, output: TxOutputType, coin: CoinInfo): def __init__(self, output: TxOutputType, coin: CoinInfo):
self.output = output self.output = output
self.coin = coin self.coin = coin
class UiConfirmTotal: class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinInfo): def __init__(self, spending: int, fee: int, coin: CoinInfo):
self.spending = spending self.spending = spending
self.fee = fee self.fee = fee
@ -35,14 +33,12 @@ class UiConfirmTotal:
class UiConfirmFeeOverThreshold: class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinInfo): def __init__(self, fee: int, coin: CoinInfo):
self.fee = fee self.fee = fee
self.coin = coin self.coin = coin
class UiConfirmForeignAddress: class UiConfirmForeignAddress:
def __init__(self, address_n: list, coin: CoinInfo): def __init__(self, address_n: list, coin: CoinInfo):
self.address_n = address_n self.address_n = address_n
self.coin = coin self.coin = coin
@ -64,7 +60,7 @@ def confirm_foreign_address(address_n: list, coin: CoinInfo):
return (yield UiConfirmForeignAddress(address_n, coin)) return (yield UiConfirmForeignAddress(address_n, coin))
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): def request_tx_meta(tx_req: TxRequest, tx_hash: bytes = None):
tx_req.request_type = TXMETA tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None tx_req.details.request_index = None
@ -73,7 +69,9 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
return sanitize_tx_meta(ack.tx) return sanitize_tx_meta(ack.tx)
def request_tx_extra_data(tx_req: TxRequest, offset: int, size: int, tx_hash: bytes=None): def request_tx_extra_data(
tx_req: TxRequest, offset: int, size: int, tx_hash: bytes = None
):
tx_req.request_type = TXEXTRADATA tx_req.request_type = TXEXTRADATA
tx_req.details.extra_data_offset = offset tx_req.details.extra_data_offset = offset
tx_req.details.extra_data_len = size tx_req.details.extra_data_len = size
@ -84,7 +82,7 @@ def request_tx_extra_data(tx_req: TxRequest, offset: int, size: int, tx_hash: by
return ack.tx.extra_data return ack.tx.extra_data
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None): def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes = None):
tx_req.request_type = TXINPUT tx_req.request_type = TXINPUT
tx_req.details.request_index = i tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
@ -93,7 +91,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
return sanitize_tx_input(ack.tx) return sanitize_tx_input(ack.tx)
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None): def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes = None):
tx_req.request_type = TXOUTPUT tx_req.request_type = TXOUTPUT
tx_req.details.request_index = i tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
@ -121,7 +119,7 @@ def sanitize_sign_tx(tx: SignTx) -> SignTx:
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0 tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0 tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0 tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin' tx.coin_name = tx.coin_name if tx.coin_name is not None else "Bitcoin"
tx.expiry = tx.expiry if tx.expiry is not None else 0 tx.expiry = tx.expiry if tx.expiry is not None else 0
tx.overwintered = tx.overwintered if tx.overwintered is not None else False tx.overwintered = tx.overwintered if tx.overwintered is not None else False
return tx return tx

@ -10,7 +10,7 @@ from apps.wallet.sign_tx import addresses
def format_coin_amount(amount, coin): def format_coin_amount(amount, coin):
return '%s %s' % (format_amount(amount, 8), coin.coin_shortcut) return "%s %s" % (format_amount(amount, 8), coin.coin_shortcut)
def split_address(address): def split_address(address):
@ -25,39 +25,36 @@ async def confirm_output(ctx, output, coin):
if output.script_type == OutputScriptType.PAYTOOPRETURN: if output.script_type == OutputScriptType.PAYTOOPRETURN:
data = hexlify(output.op_return_data).decode() data = hexlify(output.op_return_data).decode()
if len(data) >= 18 * 5: if len(data) >= 18 * 5:
data = data[:(18 * 5 - 3)] + '...' data = data[: (18 * 5 - 3)] + "..."
text = Text('OP_RETURN', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("OP_RETURN", ui.ICON_SEND, icon_color=ui.GREEN)
text.mono(*split_op_return(data)) text.mono(*split_op_return(data))
else: else:
address = output.address address = output.address
address_short = addresses.address_short(coin, address) address_short = addresses.address_short(coin, address)
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(format_coin_amount(output.amount, coin) + ' to') text.normal(format_coin_amount(output.amount, coin) + " to")
text.mono(*split_address(address_short)) text.mono(*split_address(address_short))
return await confirm(ctx, text, ButtonRequestType.ConfirmOutput) return await confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_total(ctx, spending, fee, coin): async def confirm_total(ctx, spending, fee, coin):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Total amount:') text.normal("Total amount:")
text.bold(format_coin_amount(spending, coin)) text.bold(format_coin_amount(spending, coin))
text.normal('including fee:') text.normal("including fee:")
text.bold(format_coin_amount(fee, coin)) text.bold(format_coin_amount(fee, coin))
return await hold_to_confirm(ctx, text, ButtonRequestType.SignTx) return await hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
async def confirm_feeoverthreshold(ctx, fee, coin): async def confirm_feeoverthreshold(ctx, fee, coin):
text = Text('High fee', ui.ICON_SEND, icon_color=ui.GREEN) text = Text("High fee", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('The fee of') text.normal("The fee of")
text.bold(format_coin_amount(fee, coin)) text.bold(format_coin_amount(fee, coin))
text.normal('is unexpectedly high.', 'Continue?') text.normal("is unexpectedly high.", "Continue?")
return await confirm(ctx, text, ButtonRequestType.FeeOverThreshold) return await confirm(ctx, text, ButtonRequestType.FeeOverThreshold)
async def confirm_foreign_address(ctx, address_n, coin): async def confirm_foreign_address(ctx, address_n, coin):
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.RED) text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.RED)
text.normal( text.normal("Trying to spend", "coins from another chain.", "Continue?")
'Trying to spend',
'coins from another chain.',
'Continue?')
return await confirm(ctx, text, ButtonRequestType.SignTx) return await confirm(ctx, text, ButtonRequestType.SignTx)

@ -40,12 +40,12 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
n = len(pubkeys) n = len(pubkeys)
if n < 1 or n > 15 or m < 1 or m > 15: if n < 1 or n > 15 or m < 1 or m > 15:
raise MultisigError(FailureType.DataError, 'Invalid multisig parameters') raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
for hd in pubkeys: for hd in pubkeys:
d = hd.node d = hd.node
if len(d.public_key) != 33 or len(d.chain_code) != 32: if len(d.public_key) != 33 or len(d.chain_code) != 32:
raise MultisigError(FailureType.DataError, 'Invalid multisig parameters') raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
# casting to bytes(), sorting on bytearray() is not supported in MicroPython # casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key)) pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key))
@ -68,8 +68,7 @@ def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) ->
for i, hd in enumerate(multisig.pubkeys): for i, hd in enumerate(multisig.pubkeys):
if multisig_get_pubkey(hd) == pubkey: if multisig_get_pubkey(hd) == pubkey:
return i return i
raise MultisigError(FailureType.DataError, raise MultisigError(FailureType.DataError, "Pubkey not found in multisig script")
'Pubkey not found in multisig script')
def multisig_get_pubkey(hd: HDNodePathType) -> bytes: def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
@ -80,7 +79,8 @@ def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
fingerprint=n.fingerprint, fingerprint=n.fingerprint,
child_num=n.child_num, child_num=n.child_num,
chain_code=n.chain_code, chain_code=n.chain_code,
public_key=n.public_key) public_key=n.public_key,
)
for i in p: for i in p:
node.derive(i, True) node.derive(i, True)
return node.public_key() return node.public_key()

@ -28,11 +28,10 @@ class Zip143Error(ValueError):
class Zip143: class Zip143:
def __init__(self): def __init__(self):
self.h_prevouts = HashWriter(blake2b, outlen=32, personal=b'ZcashPrevoutHash') self.h_prevouts = HashWriter(blake2b, outlen=32, personal=b"ZcashPrevoutHash")
self.h_sequence = HashWriter(blake2b, outlen=32, personal=b'ZcashSequencHash') self.h_sequence = HashWriter(blake2b, outlen=32, personal=b"ZcashSequencHash")
self.h_outputs = HashWriter(blake2b, outlen=32, personal=b'ZcashOutputsHash') self.h_outputs = HashWriter(blake2b, outlen=32, personal=b"ZcashOutputsHash")
def add_prevouts(self, txi: TxInputType): def add_prevouts(self, txi: TxInputType):
write_bytes_rev(self.h_prevouts, txi.prev_hash) write_bytes_rev(self.h_prevouts, txi.prev_hash)
@ -53,31 +52,42 @@ class Zip143:
def get_outputs_hash(self) -> bytes: def get_outputs_hash(self) -> bytes:
return get_tx_hash(self.h_outputs) return get_tx_hash(self.h_outputs)
def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes: def preimage_hash(
h_preimage = HashWriter(blake2b, outlen=32, personal=b'ZcashSigHash\x19\x1b\xa8\x5b') # BRANCH_ID = 0x5ba81b19 self,
coin: CoinInfo,
tx: SignTx,
txi: TxInputType,
pubkeyhash: bytes,
sighash: int,
) -> bytes:
h_preimage = HashWriter(
blake2b, outlen=32, personal=b"ZcashSigHash\x19\x1b\xa8\x5b"
) # BRANCH_ID = 0x5ba81b19
assert tx.overwintered assert tx.overwintered
write_uint32(h_preimage, tx.version | OVERWINTERED) # 1. nVersion | fOverwintered write_uint32(
write_uint32(h_preimage, coin.version_group_id) # 2. nVersionGroupId h_preimage, tx.version | OVERWINTERED
) # 1. nVersion | fOverwintered
write_uint32(h_preimage, coin.version_group_id) # 2. nVersionGroupId
write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # 3. hashPrevouts write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # 3. hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # 4. hashSequence write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # 4. hashSequence
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # 5. hashOutputs write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # 5. hashOutputs
write_bytes(h_preimage, b'\x00' * 32) # 6. hashJoinSplits write_bytes(h_preimage, b"\x00" * 32) # 6. hashJoinSplits
write_uint32(h_preimage, tx.lock_time) # 7. nLockTime write_uint32(h_preimage, tx.lock_time) # 7. nLockTime
write_uint32(h_preimage, tx.expiry) # 8. expiryHeight write_uint32(h_preimage, tx.expiry) # 8. expiryHeight
write_uint32(h_preimage, sighash) # 9. nHashType write_uint32(h_preimage, sighash) # 9. nHashType
write_bytes_rev(h_preimage, txi.prev_hash) # 10a. outpoint write_bytes_rev(h_preimage, txi.prev_hash) # 10a. outpoint
write_uint32(h_preimage, txi.prev_index) write_uint32(h_preimage, txi.prev_index)
script_code = self.derive_script_code(txi, pubkeyhash) # 10b. scriptCode script_code = self.derive_script_code(txi, pubkeyhash) # 10b. scriptCode
write_varint(h_preimage, len(script_code)) write_varint(h_preimage, len(script_code))
write_bytes(h_preimage, script_code) write_bytes(h_preimage, script_code)
write_uint64(h_preimage, txi.amount) # 10c. value write_uint64(h_preimage, txi.amount) # 10c. value
write_uint32(h_preimage, txi.sequence) # 10d. nSequence write_uint32(h_preimage, txi.sequence) # 10d. nSequence
return get_tx_hash(h_preimage) return get_tx_hash(h_preimage)
@ -86,12 +96,16 @@ class Zip143:
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray: def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
if txi.multisig: if txi.multisig:
return output_script_multisig(multisig_get_pubkeys(txi.multisig), txi.multisig.m) return output_script_multisig(
multisig_get_pubkeys(txi.multisig), txi.multisig.m
)
p2pkh = txi.script_type == InputScriptType.SPENDADDRESS p2pkh = txi.script_type == InputScriptType.SPENDADDRESS
if p2pkh: if p2pkh:
return output_script_p2pkh(pubkeyhash) return output_script_p2pkh(pubkeyhash)
else: else:
raise Zip143Error(FailureType.DataError, raise Zip143Error(
'Unknown input script type for zip143 script code') FailureType.DataError,
"Unknown input script type for zip143 script code",
)

@ -20,7 +20,7 @@ def advance():
def report_init(): def report_init():
ui.display.clear() ui.display.clear()
ui.header('Signing transaction') ui.header("Signing transaction")
def report(): def report():

@ -51,7 +51,9 @@ def output_script_p2sh(scripthash: bytes) -> bytearray:
return s return s
def script_replay_protection_bip115(block_hash: bytes, block_height: int) -> bytearray: def script_replay_protection_bip115(
block_hash: bytes, block_height: bytes
) -> bytearray:
if block_hash is None or block_height is None: if block_hash is None or block_height is None:
return bytearray() return bytearray()
assert len(block_hash) == 32 assert len(block_hash) == 32

@ -24,7 +24,6 @@ class Bip143Error(ValueError):
class Bip143: class Bip143:
def __init__(self): def __init__(self):
self.h_prevouts = HashWriter(sha256) self.h_prevouts = HashWriter(sha256)
self.h_sequence = HashWriter(sha256) self.h_sequence = HashWriter(sha256)
@ -49,27 +48,34 @@ class Bip143:
def get_outputs_hash(self, coin: CoinInfo) -> bytes: def get_outputs_hash(self, coin: CoinInfo) -> bytes:
return get_tx_hash(self.h_outputs, double=coin.sign_hash_double) return get_tx_hash(self.h_outputs, double=coin.sign_hash_double)
def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes: def preimage_hash(
self,
coin: CoinInfo,
tx: SignTx,
txi: TxInputType,
pubkeyhash: bytes,
sighash: int,
) -> bytes:
h_preimage = HashWriter(sha256) h_preimage = HashWriter(sha256)
assert not tx.overwintered assert not tx.overwintered
write_uint32(h_preimage, tx.version) # nVersion write_uint32(h_preimage, tx.version) # nVersion
write_bytes(h_preimage, bytearray(self.get_prevouts_hash(coin))) # hashPrevouts write_bytes(h_preimage, bytearray(self.get_prevouts_hash(coin))) # hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash(coin))) # hashSequence write_bytes(h_preimage, bytearray(self.get_sequence_hash(coin))) # hashSequence
write_bytes_rev(h_preimage, txi.prev_hash) # outpoint write_bytes_rev(h_preimage, txi.prev_hash) # outpoint
write_uint32(h_preimage, txi.prev_index) # outpoint write_uint32(h_preimage, txi.prev_index) # outpoint
script_code = self.derive_script_code(txi, pubkeyhash) # scriptCode script_code = self.derive_script_code(txi, pubkeyhash) # scriptCode
write_varint(h_preimage, len(script_code)) write_varint(h_preimage, len(script_code))
write_bytes(h_preimage, script_code) write_bytes(h_preimage, script_code)
write_uint64(h_preimage, txi.amount) # amount write_uint64(h_preimage, txi.amount) # amount
write_uint32(h_preimage, txi.sequence) # nSequence write_uint32(h_preimage, txi.sequence) # nSequence
write_bytes(h_preimage, bytearray(self.get_outputs_hash(coin))) # hashOutputs write_bytes(h_preimage, bytearray(self.get_outputs_hash(coin))) # hashOutputs
write_uint32(h_preimage, tx.lock_time) # nLockTime write_uint32(h_preimage, tx.lock_time) # nLockTime
write_uint32(h_preimage, sighash) # nHashType write_uint32(h_preimage, sighash) # nHashType
return get_tx_hash(h_preimage, double=coin.sign_hash_double) return get_tx_hash(h_preimage, double=coin.sign_hash_double)
@ -78,16 +84,22 @@ class Bip143:
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray: def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
if txi.multisig: if txi.multisig:
return output_script_multisig(multisig_get_pubkeys(txi.multisig), txi.multisig.m) return output_script_multisig(
multisig_get_pubkeys(txi.multisig), txi.multisig.m
p2pkh = (txi.script_type == InputScriptType.SPENDWITNESS or )
txi.script_type == InputScriptType.SPENDP2SHWITNESS or
txi.script_type == InputScriptType.SPENDADDRESS) p2pkh = (
txi.script_type == InputScriptType.SPENDWITNESS
or txi.script_type == InputScriptType.SPENDP2SHWITNESS
or txi.script_type == InputScriptType.SPENDADDRESS
)
if p2pkh: if p2pkh:
# for p2wpkh in p2sh or native p2wpkh # for p2wpkh in p2sh or native p2wpkh
# the scriptCode is a classic p2pkh # the scriptCode is a classic p2pkh
return output_script_p2pkh(pubkeyhash) return output_script_p2pkh(pubkeyhash)
else: else:
raise Bip143Error(FailureType.DataError, raise Bip143Error(
'Unknown input script type for bip143 script code') FailureType.DataError,
"Unknown input script type for bip143 script code",
)

@ -94,35 +94,40 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
if txi.multisig: if txi.multisig:
multifp.add(txi.multisig) multifp.add(txi.multisig)
if txi.script_type in (InputScriptType.SPENDWITNESS, if txi.script_type in (
InputScriptType.SPENDP2SHWITNESS): InputScriptType.SPENDWITNESS,
InputScriptType.SPENDP2SHWITNESS,
):
if not coin.segwit: if not coin.segwit:
raise SigningError(FailureType.DataError, raise SigningError(
'Segwit not enabled on this coin') FailureType.DataError, "Segwit not enabled on this coin"
)
if not txi.amount: if not txi.amount:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError, "Segwit input without amount")
'Segwit input without amount')
segwit[i] = True segwit[i] = True
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
elif txi.script_type in (InputScriptType.SPENDADDRESS, elif txi.script_type in (
InputScriptType.SPENDMULTISIG): InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG,
):
if coin.force_bip143 or tx.overwintered: if coin.force_bip143 or tx.overwintered:
if not txi.amount: if not txi.amount:
raise SigningError(FailureType.DataError, raise SigningError(
'BIP/ZIP 143 input without amount') FailureType.DataError, "BIP/ZIP 143 input without amount"
)
segwit[i] = False segwit[i] = False
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
else: else:
segwit[i] = False segwit[i] = False
total_in += await get_prevtx_output_value( total_in += await get_prevtx_output_value(
coin, tx_req, txi.prev_hash, txi.prev_index) coin, tx_req, txi.prev_hash, txi.prev_index
)
else: else:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError, "Wrong input script type")
'Wrong input script type')
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_3_OUTPUT # STAGE_REQUEST_3_OUTPUT
@ -135,8 +140,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
# output is change and does not need confirmation # output is change and does not need confirmation
change_out = txo.amount change_out = txo.amount
elif not await confirm_output(txo, coin): elif not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled, raise SigningError(FailureType.ActionCancelled, "Output cancelled")
'Output cancelled')
write_tx_output(h_first, txo_bin) write_tx_output(h_first, txo_bin)
hash143.add_output(txo_bin) hash143.add_output(txo_bin)
@ -144,18 +148,15 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
fee = total_in - total_out fee = total_in - total_out
if fee < 0: if fee < 0:
raise SigningError(FailureType.NotEnoughFunds, raise SigningError(FailureType.NotEnoughFunds, "Not enough funds")
'Not enough funds')
# fee > (coin.maxfee per byte * tx size) # fee > (coin.maxfee per byte * tx size)
if fee > (coin.maxfee_kb / 1000) * (weight.get_total() / 4): if fee > (coin.maxfee_kb / 1000) * (weight.get_total() / 4):
if not await confirm_feeoverthreshold(fee, coin): if not await confirm_feeoverthreshold(fee, coin):
raise SigningError(FailureType.ActionCancelled, raise SigningError(FailureType.ActionCancelled, "Signing cancelled")
'Signing cancelled')
if not await confirm_total(total_out - change_out, fee, coin): if not await confirm_total(total_out - change_out, fee, coin):
raise SigningError(FailureType.ActionCancelled, raise SigningError(FailureType.ActionCancelled, "Total cancelled")
'Total cancelled')
return h_first, hash143, segwit, total_in, wallet_path return h_first, hash143, segwit, total_in, wallet_path
@ -191,11 +192,14 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign) txi_sign = await request_tx_input(tx_req, i_sign)
is_segwit = (txi_sign.script_type == InputScriptType.SPENDWITNESS or is_segwit = (
txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS) txi_sign.script_type == InputScriptType.SPENDWITNESS
or txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS
)
if not is_segwit: if not is_segwit:
raise SigningError(FailureType.ProcessError, raise SigningError(
'Transaction has changed during signing') FailureType.ProcessError, "Transaction has changed during signing"
)
input_check_wallet_path(txi_sign, wallet_path) input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n) key_sign = node_derive(root, txi_sign.address_n)
@ -203,7 +207,8 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub) txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
w_txi = bytearray_with_cap( w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(coin, tx, True)) write_bytes(w_txi, get_tx_header(coin, tx, True))
write_tx_input(w_txi, txi_sign) write_tx_input(w_txi, txi_sign)
@ -215,17 +220,21 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign = await request_tx_input(tx_req, i_sign) txi_sign = await request_tx_input(tx_req, i_sign)
input_check_wallet_path(txi_sign, wallet_path) input_check_wallet_path(txi_sign, wallet_path)
is_bip143 = (txi_sign.script_type == InputScriptType.SPENDADDRESS or is_bip143 = (
txi_sign.script_type == InputScriptType.SPENDMULTISIG) txi_sign.script_type == InputScriptType.SPENDADDRESS
or txi_sign.script_type == InputScriptType.SPENDMULTISIG
)
if not is_bip143 or txi_sign.amount > authorized_in: if not is_bip143 or txi_sign.amount > authorized_in:
raise SigningError(FailureType.ProcessError, raise SigningError(
'Transaction has changed during signing') FailureType.ProcessError, "Transaction has changed during signing"
)
authorized_in -= txi_sign.amount authorized_in -= txi_sign.amount
key_sign = node_derive(root, txi_sign.address_n) key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash( hash143_hash = hash143.preimage_hash(
coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)
)
# if multisig, check if singing with a key that is included in multisig # if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig: if txi_sign.multisig:
@ -237,9 +246,11 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# serialize input with correct signature # serialize input with correct signature
txi_sign.script_sig = input_derive_script( txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature) coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = bytearray_with_cap( w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(coin, tx)) write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign) write_tx_input(w_txi_sign, txi_sign)
@ -254,10 +265,12 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
h_second = HashWriter(sha256) h_second = HashWriter(sha256)
if tx.overwintered: if tx.overwintered:
write_uint32(h_sign, tx.version | OVERWINTERED) # nVersion | fOverwintered write_uint32(
write_uint32(h_sign, coin.version_group_id) # nVersionGroupId h_sign, tx.version | OVERWINTERED
) # nVersion | fOverwintered
write_uint32(h_sign, coin.version_group_id) # nVersionGroupId
else: else:
write_uint32(h_sign, tx.version) # nVersion write_uint32(h_sign, tx.version) # nVersion
write_varint(h_sign, tx.inputs_count) write_varint(h_sign, tx.inputs_count)
@ -274,15 +287,21 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
if txi_sign.script_type == InputScriptType.SPENDMULTISIG: if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
txi_sign.script_sig = output_script_multisig( txi_sign.script_sig = output_script_multisig(
multisig_get_pubkeys(txi_sign.multisig), multisig_get_pubkeys(txi_sign.multisig), txi_sign.multisig.m
txi_sign.multisig.m) )
elif txi_sign.script_type == InputScriptType.SPENDADDRESS: elif txi_sign.script_type == InputScriptType.SPENDADDRESS:
txi_sign.script_sig = output_script_p2pkh(ecdsa_hash_pubkey(key_sign_pub)) txi_sign.script_sig = output_script_p2pkh(
ecdsa_hash_pubkey(key_sign_pub)
)
if coin.bip115: if coin.bip115:
txi_sign.script_sig += script_replay_protection_bip115(txi_sign.prev_block_hash_bip115, txi_sign.prev_block_height_bip115) txi_sign.script_sig += script_replay_protection_bip115(
txi_sign.prev_block_hash_bip115,
txi_sign.prev_block_height_bip115,
)
else: else:
raise SigningError(FailureType.ProcessError, raise SigningError(
'Unknown transaction type') FailureType.ProcessError, "Unknown transaction type"
)
else: else:
txi.script_sig = bytes() txi.script_sig = bytes()
write_tx_input(h_sign, txi) write_tx_input(h_sign, txi)
@ -300,29 +319,34 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
write_uint32(h_sign, tx.lock_time) write_uint32(h_sign, tx.lock_time)
if tx.overwintered: if tx.overwintered:
write_uint32(h_sign, tx.expiry) # expiryHeight write_uint32(h_sign, tx.expiry) # expiryHeight
write_varint(h_sign, 0) # nJoinSplit write_varint(h_sign, 0) # nJoinSplit
write_uint32(h_sign, get_hash_type(coin)) write_uint32(h_sign, get_hash_type(coin))
# check the control digests # check the control digests
if get_tx_hash(h_first, False) != get_tx_hash(h_second): if get_tx_hash(h_first, False) != get_tx_hash(h_second):
raise SigningError(FailureType.ProcessError, raise SigningError(
'Transaction has changed during signing') FailureType.ProcessError, "Transaction has changed during signing"
)
# if multisig, check if singing with a key that is included in multisig # if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig: if txi_sign.multisig:
multisig_pubkey_index(txi_sign.multisig, key_sign_pub) multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
# compute the signature from the tx digest # compute the signature from the tx digest
signature = ecdsa_sign(key_sign, get_tx_hash(h_sign, double=coin.sign_hash_double)) signature = ecdsa_sign(
key_sign, get_tx_hash(h_sign, double=coin.sign_hash_double)
)
tx_ser.signature_index = i_sign tx_ser.signature_index = i_sign
tx_ser.signature = signature tx_ser.signature = signature
# serialize input with correct signature # serialize input with correct signature
txi_sign.script_sig = input_derive_script( txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature) coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = bytearray_with_cap( w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(coin, tx)) write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign) write_tx_input(w_txi_sign, txi_sign)
@ -338,8 +362,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
# serialize output # serialize output
w_txo_bin = bytearray_with_cap( w_txo_bin = bytearray_with_cap(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
if o == 0: # serializing first output => prepend outputs count if o == 0: # serializing first output => prepend outputs count
write_varint(w_txo_bin, tx.outputs_count) write_varint(w_txo_bin, tx.outputs_count)
write_tx_output(w_txo_bin, txo_bin) write_tx_output(w_txo_bin, txo_bin)
@ -359,23 +382,29 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi = await request_tx_input(tx_req, i) txi = await request_tx_input(tx_req, i)
input_check_wallet_path(txi, wallet_path) input_check_wallet_path(txi, wallet_path)
is_segwit = (txi.script_type == InputScriptType.SPENDWITNESS or is_segwit = (
txi.script_type == InputScriptType.SPENDP2SHWITNESS) txi.script_type == InputScriptType.SPENDWITNESS
or txi.script_type == InputScriptType.SPENDP2SHWITNESS
)
if not is_segwit or txi.amount > authorized_in: if not is_segwit or txi.amount > authorized_in:
raise SigningError(FailureType.ProcessError, raise SigningError(
'Transaction has changed during signing') FailureType.ProcessError, "Transaction has changed during signing"
)
authorized_in -= txi.amount authorized_in -= txi.amount
key_sign = node_derive(root, txi.address_n) key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash( hash143_hash = hash143.preimage_hash(
coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)
)
signature = ecdsa_sign(key_sign, hash143_hash) signature = ecdsa_sign(key_sign, hash143_hash)
if txi.multisig: if txi.multisig:
# find out place of our signature based on the pubkey # find out place of our signature based on the pubkey
signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub) signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub)
witness = witness_p2wsh(txi.multisig, signature, signature_index, get_hash_type(coin)) witness = witness_p2wsh(
txi.multisig, signature, signature_index, get_hash_type(coin)
)
else: else:
witness = witness_p2wpkh(signature, key_sign_pub, get_hash_type(coin)) witness = witness_p2wpkh(signature, key_sign_pub, get_hash_type(coin))
@ -392,12 +421,14 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
write_uint32(tx_ser.serialized_tx, tx.lock_time) write_uint32(tx_ser.serialized_tx, tx.lock_time)
if tx.overwintered: if tx.overwintered:
write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight
write_varint(tx_ser.serialized_tx, 0) # nJoinSplit write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
await request_tx_finish(tx_req) await request_tx_finish(tx_req)
async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int: async def get_prevtx_output_value(
coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int
) -> int:
total_out = 0 # sum of output amounts total_out = 0 # sum of output amounts
# STAGE_REQUEST_2_PREV_META # STAGE_REQUEST_2_PREV_META
@ -407,9 +438,9 @@ async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash:
if tx.overwintered: if tx.overwintered:
write_uint32(txh, tx.version | OVERWINTERED) # nVersion | fOverwintered write_uint32(txh, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(txh, coin.version_group_id) # nVersionGroupId write_uint32(txh, coin.version_group_id) # nVersionGroupId
else: else:
write_uint32(txh, tx.version) # nVersion write_uint32(txh, tx.version) # nVersion
write_varint(txh, tx.inputs_cnt) write_varint(txh, tx.inputs_cnt)
@ -440,8 +471,7 @@ async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash:
ofs += len(data) ofs += len(data)
if get_tx_hash(txh, double=coin.sign_hash_double, reverse=True) != prev_hash: if get_tx_hash(txh, double=coin.sign_hash_double, reverse=True) != prev_hash:
raise SigningError(FailureType.ProcessError, raise SigningError(FailureType.ProcessError, "Encountered invalid prev_hash")
'Encountered invalid prev_hash')
return total_out return total_out
@ -463,7 +493,7 @@ def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False):
w_txi = bytearray() w_txi = bytearray()
if tx.overwintered: if tx.overwintered:
write_uint32(w_txi, tx.version | OVERWINTERED) # nVersion | fOverwintered write_uint32(w_txi, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(w_txi, coin.version_group_id) # nVersionGroupId write_uint32(w_txi, coin.version_group_id) # nVersionGroupId
else: else:
write_uint32(w_txi, tx.version) # nVersion write_uint32(w_txi, tx.version) # nVersion
if segwit: if segwit:
@ -482,33 +512,36 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
if o.script_type == OutputScriptType.PAYTOOPRETURN: if o.script_type == OutputScriptType.PAYTOOPRETURN:
# op_return output # op_return output
if o.amount != 0: if o.amount != 0:
raise SigningError(FailureType.DataError, raise SigningError(
'OP_RETURN output with non-zero amount') FailureType.DataError, "OP_RETURN output with non-zero amount"
)
return output_script_paytoopreturn(o.op_return_data) return output_script_paytoopreturn(o.op_return_data)
if o.address_n: if o.address_n:
# change output # change output
if o.address: if o.address:
raise SigningError(FailureType.DataError, 'Address in change output') raise SigningError(FailureType.DataError, "Address in change output")
o.address = get_address_for_change(o, coin, root) o.address = get_address_for_change(o, coin, root)
else: else:
if not o.address: if not o.address:
raise SigningError(FailureType.DataError, 'Missing address') raise SigningError(FailureType.DataError, "Missing address")
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh # p2wpkh or p2wsh
witprog = decode_bech32_address(coin.bech32_prefix, o.address) witprog = decode_bech32_address(coin.bech32_prefix, o.address)
return output_script_native_p2wpkh_or_p2wsh(witprog) return output_script_native_p2wpkh_or_p2wsh(witprog)
if coin.cashaddr_prefix is not None and o.address.startswith(coin.cashaddr_prefix + ':'): if coin.cashaddr_prefix is not None and o.address.startswith(
prefix, addr = o.address.split(':') coin.cashaddr_prefix + ":"
):
prefix, addr = o.address.split(":")
version, data = cashaddr.decode(prefix, addr) version, data = cashaddr.decode(prefix, addr)
if version == cashaddr.ADDRESS_TYPE_P2KH: if version == cashaddr.ADDRESS_TYPE_P2KH:
version = coin.address_type version = coin.address_type
elif version == cashaddr.ADDRESS_TYPE_P2SH: elif version == cashaddr.ADDRESS_TYPE_P2SH:
version = coin.address_type_p2sh version = coin.address_type_p2sh
else: else:
raise ValueError('Unknown cashaddr address type') raise ValueError("Unknown cashaddr address type")
raw_address = bytes([version]) + data raw_address = bytes([version]) + data
else: else:
raw_address = base58.decode_check(o.address, coin.b58_hash) raw_address = base58.decode_check(o.address, coin.b58_hash)
@ -518,7 +551,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
pubkeyhash = address_type.strip(coin.address_type, raw_address) pubkeyhash = address_type.strip(coin.address_type, raw_address)
script = output_script_p2pkh(pubkeyhash) script = output_script_p2pkh(pubkeyhash)
if coin.bip115: if coin.bip115:
script += script_replay_protection_bip115(o.block_hash_bip115, o.block_height_bip115) script += script_replay_protection_bip115(
o.block_hash_bip115, o.block_height_bip115
)
return script return script
elif address_type.check(coin.address_type_p2sh, raw_address): elif address_type.check(coin.address_type_p2sh, raw_address):
@ -526,10 +561,12 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
scripthash = address_type.strip(coin.address_type_p2sh, raw_address) scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
script = output_script_p2sh(scripthash) script = output_script_p2sh(scripthash)
if coin.bip115: if coin.bip115:
script += script_replay_protection_bip115(o.block_hash_bip115, o.block_height_bip115) script += script_replay_protection_bip115(
o.block_hash_bip115, o.block_height_bip115
)
return script return script
raise SigningError(FailureType.DataError, 'Invalid address type') raise SigningError(FailureType.DataError, "Invalid address type")
def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode): def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode):
@ -542,33 +579,40 @@ def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode):
elif o.script_type == OutputScriptType.PAYTOP2SHWITNESS: elif o.script_type == OutputScriptType.PAYTOP2SHWITNESS:
input_script_type = InputScriptType.SPENDP2SHWITNESS input_script_type = InputScriptType.SPENDP2SHWITNESS
else: else:
raise SigningError(FailureType.DataError, 'Invalid script type') raise SigningError(FailureType.DataError, "Invalid script type")
return get_address(input_script_type, coin, node_derive(root, o.address_n), o.multisig) return get_address(
input_script_type, coin, node_derive(root, o.address_n), o.multisig
)
def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool: def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool:
is_segwit = (o.script_type == OutputScriptType.PAYTOWITNESS or is_segwit = (
o.script_type == OutputScriptType.PAYTOP2SHWITNESS) o.script_type == OutputScriptType.PAYTOWITNESS
or o.script_type == OutputScriptType.PAYTOP2SHWITNESS
)
if is_segwit and o.amount > segwit_in: if is_segwit and o.amount > segwit_in:
# if the output is segwit, make sure it doesn't spend more than what the # if the output is segwit, make sure it doesn't spend more than what the
# segwit inputs paid. this is to prevent user being tricked into # segwit inputs paid. this is to prevent user being tricked into
# creating ANYONECANSPEND outputs before full segwit activation. # creating ANYONECANSPEND outputs before full segwit activation.
return False return False
return (wallet_path is not None and return (
wallet_path == o.address_n[:-_BIP32_WALLET_DEPTH] and wallet_path is not None
o.address_n[-2] <= _BIP32_CHANGE_CHAIN and and wallet_path == o.address_n[:-_BIP32_WALLET_DEPTH]
o.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT) and o.address_n[-2] <= _BIP32_CHANGE_CHAIN
and o.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
)
# Tx Inputs # Tx Inputs
# === # ===
def input_derive_script(coin: CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes: def input_derive_script(
coin: CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes = None
) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS: if i.script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh # p2pkh or p2sh
return input_script_p2pkh_or_p2sh( return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin))
pubkey, signature, get_hash_type(coin))
if i.script_type == InputScriptType.SPENDP2SHWITNESS: if i.script_type == InputScriptType.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh # p2wpkh or p2wsh using p2sh
@ -591,10 +635,11 @@ def input_derive_script(coin: CoinInfo, i: TxInputType, pubkey: bytes, signature
# p2sh multisig # p2sh multisig
signature_index = multisig_pubkey_index(i.multisig, pubkey) signature_index = multisig_pubkey_index(i.multisig, pubkey)
return input_script_multisig( return input_script_multisig(
i.multisig, signature, signature_index, get_hash_type(coin)) i.multisig, signature, signature_index, get_hash_type(coin)
)
else: else:
raise SigningError(FailureType.ProcessError, 'Invalid script type') raise SigningError(FailureType.ProcessError, "Invalid script type")
def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list: def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list:
@ -615,8 +660,9 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
return # there was a mismatch in Phase 1, ignore it now return # there was a mismatch in Phase 1, ignore it now
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH] address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if wallet_path != address_n: if wallet_path != address_n:
raise SigningError(FailureType.ProcessError, raise SigningError(
'Transaction has changed during signing') FailureType.ProcessError, "Transaction has changed during signing"
)
def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode: def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode:
@ -630,7 +676,9 @@ def address_n_matches_coin(address_n: list, coin: CoinInfo) -> bool:
return True # path is too short return True # path is too short
if address_n[0] not in (44 | 0x80000000, 49 | 0x80000000, 84 | 0x80000000): if address_n[0] not in (44 | 0x80000000, 49 | 0x80000000, 84 | 0x80000000):
return True # path is not BIP44/49/84 return True # path is not BIP44/49/84
return address_n[1] == (coin.slip44 | 0x80000000) # check whether coin_type matches slip44 value return address_n[1] == (
coin.slip44 | 0x80000000
) # check whether coin_type matches slip44 value
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
@ -640,10 +688,8 @@ def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
def is_change( def is_change(
txo: TxOutputType, txo: TxOutputType, wallet_path: list, segwit_in: int, multifp: MultisigFingerprint
wallet_path: list, ) -> bool:
segwit_in: int,
multifp: MultisigFingerprint) -> bool:
if txo.multisig: if txo.multisig:
if not multifp.matches(txo.multisig): if not multifp.matches(txo.multisig):
return False return False

@ -34,14 +34,14 @@ _TXSIZE_WITNESSSCRIPT = const(34)
class TxWeightCalculator: class TxWeightCalculator:
def __init__(self, inputs_count: int, outputs_count: int): def __init__(self, inputs_count: int, outputs_count: int):
self.inputs_count = inputs_count self.inputs_count = inputs_count
self.counter = 4 * ( self.counter = 4 * (
_TXSIZE_HEADER + _TXSIZE_HEADER
_TXSIZE_FOOTER + + _TXSIZE_FOOTER
self.ser_length_size(inputs_count) + + self.ser_length_size(inputs_count)
self.ser_length_size(outputs_count)) + self.ser_length_size(outputs_count)
)
self.segwit = False self.segwit = False
def add_witness_header(self): def add_witness_header(self):
@ -53,26 +53,31 @@ class TxWeightCalculator:
def add_input(self, i: TxInputType): def add_input(self, i: TxInputType):
if i.multisig: if i.multisig:
multisig_script_size = ( multisig_script_size = _TXSIZE_MULTISIGSCRIPT + len(i.multisig.pubkeys) * (
_TXSIZE_MULTISIGSCRIPT + 1 + _TXSIZE_PUBKEY
len(i.multisig.pubkeys) * (1 + _TXSIZE_PUBKEY)) )
input_script_size = ( input_script_size = (
1 + # the OP_FALSE bug in multisig 1
i.multisig.m * (1 + _TXSIZE_SIGNATURE) + + i.multisig.m * (1 + _TXSIZE_SIGNATURE) # the OP_FALSE bug in multisig
self.op_push_size(multisig_script_size) + + self.op_push_size(multisig_script_size)
multisig_script_size) + multisig_script_size
)
else: else:
input_script_size = 1 + _TXSIZE_SIGNATURE + 1 + _TXSIZE_PUBKEY input_script_size = 1 + _TXSIZE_SIGNATURE + 1 + _TXSIZE_PUBKEY
self.counter += 4 * _TXSIZE_INPUT self.counter += 4 * _TXSIZE_INPUT
if (i.script_type == InputScriptType.SPENDADDRESS or if (
i.script_type == InputScriptType.SPENDMULTISIG): i.script_type == InputScriptType.SPENDADDRESS
or i.script_type == InputScriptType.SPENDMULTISIG
):
input_script_size += self.ser_length_size(input_script_size) input_script_size += self.ser_length_size(input_script_size)
self.counter += 4 * input_script_size self.counter += 4 * input_script_size
elif (i.script_type == InputScriptType.SPENDWITNESS or elif (
i.script_type == InputScriptType.SPENDP2SHWITNESS): i.script_type == InputScriptType.SPENDWITNESS
or i.script_type == InputScriptType.SPENDP2SHWITNESS
):
self.add_witness_header() self.add_witness_header()
if i.script_type == InputScriptType.SPENDP2SHWITNESS: if i.script_type == InputScriptType.SPENDP2SHWITNESS:
if i.multisig: if i.multisig:

@ -130,7 +130,7 @@ def bytearray_with_cap(cap: int) -> bytearray:
# === # ===
def get_tx_hash(w, double: bool=False, reverse: bool=False) -> bytes: def get_tx_hash(w, double: bool = False, reverse: bool = False) -> bytes:
d = w.get_digest() d = w.get_digest()
if double: if double:
d = sha256(d).digest() d = sha256(d).digest()

@ -21,7 +21,7 @@ async def verify_message(ctx, msg):
message = msg.message message = msg.message
address = msg.address address = msg.address
signature = msg.signature signature = msg.signature
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
digest = message_digest(coin, message) digest = message_digest(coin, message)
@ -37,12 +37,12 @@ async def verify_message(ctx, msg):
script_type = SPENDWITNESS # native segwit script_type = SPENDWITNESS # native segwit
signature = bytes([signature[0] - 8]) + signature[1:] signature = bytes([signature[0] - 8]) + signature[1:]
else: else:
raise wire.ProcessError('Invalid signature') raise wire.ProcessError("Invalid signature")
pubkey = secp256k1.verify_recover(signature, digest) pubkey = secp256k1.verify_recover(signature, digest)
if not pubkey: if not pubkey:
raise wire.ProcessError('Invalid signature') raise wire.ProcessError("Invalid signature")
if script_type == SPENDADDRESS: if script_type == SPENDADDRESS:
addr = address_pkh(pubkey, coin) addr = address_pkh(pubkey, coin)
@ -53,21 +53,21 @@ async def verify_message(ctx, msg):
elif script_type == SPENDWITNESS: elif script_type == SPENDWITNESS:
addr = address_p2wpkh(pubkey, coin.bech32_prefix) addr = address_p2wpkh(pubkey, coin.bech32_prefix)
else: else:
raise wire.ProcessError('Invalid signature') raise wire.ProcessError("Invalid signature")
if addr != address: if addr != address:
raise wire.ProcessError('Invalid signature') raise wire.ProcessError("Invalid signature")
await require_confirm_verify_message(ctx, address_short(coin, address), message) await require_confirm_verify_message(ctx, address_short(coin, address), message)
return Success(message='Message verified') return Success(message="Message verified")
async def require_confirm_verify_message(ctx, address, message): async def require_confirm_verify_message(ctx, address, message):
text = Text('Confirm address') text = Text("Confirm address")
text.mono(*split_address(address)) text.mono(*split_address(address))
await require_confirm(ctx, text) await require_confirm(ctx, text)
text = Text('Verify message') text = Text("Verify message")
text.normal(*split_message(message)) text.normal(*split_message(message))
await require_confirm(ctx, text) await require_confirm(ctx, text)

@ -8,7 +8,7 @@ async def bootscreen():
while True: while True:
try: try:
if not config.has_pin(): if not config.has_pin():
config.unlock(pin_to_int(''), show_pin_timeout) config.unlock(pin_to_int(""), show_pin_timeout)
return return
await lockscreen() await lockscreen()
label = None label = None
@ -17,7 +17,7 @@ async def bootscreen():
if config.unlock(pin_to_int(pin), show_pin_timeout): if config.unlock(pin_to_int(pin), show_pin_timeout):
return return
else: else:
label = 'Wrong PIN, enter again' label = "Wrong PIN, enter again"
except: # noqa: E722 except: # noqa: E722
pass pass
@ -28,9 +28,9 @@ async def lockscreen():
label = storage.get_label() label = storage.get_label()
image = storage.get_homescreen() image = storage.get_homescreen()
if not label: if not label:
label = 'My TREZOR' label = "My TREZOR"
if not image: if not image:
image = res.load('apps/homescreen/res/bg.toif') image = res.load("apps/homescreen/res/bg.toif")
await ui.backlight_slide(ui.BACKLIGHT_DIM) await ui.backlight_slide(ui.BACKLIGHT_DIM)
@ -40,9 +40,11 @@ async def lockscreen():
ui.display.bar_radius(40, 100, 160, 40, ui.TITLE_GREY, ui.BG, 4) ui.display.bar_radius(40, 100, 160, 40, ui.TITLE_GREY, ui.BG, 4)
ui.display.bar_radius(42, 102, 156, 36, ui.BG, ui.TITLE_GREY, 4) ui.display.bar_radius(42, 102, 156, 36, ui.BG, ui.TITLE_GREY, 4)
ui.display.text_center(ui.WIDTH // 2, 128, 'Locked', ui.BOLD, ui.TITLE_GREY, ui.BG) ui.display.text_center(ui.WIDTH // 2, 128, "Locked", ui.BOLD, ui.TITLE_GREY, ui.BG)
ui.display.text_center(ui.WIDTH // 2 + 10, 220, 'Tap to unlock', ui.BOLD, ui.TITLE_GREY, ui.BG) ui.display.text_center(
ui.WIDTH // 2 + 10, 220, "Tap to unlock", ui.BOLD, ui.TITLE_GREY, ui.BG
)
ui.display.icon(45, 202, res.load(ui.ICON_CLICK), ui.TITLE_GREY, ui.BG) ui.display.icon(45, 202, res.load(ui.ICON_CLICK), ui.TITLE_GREY, ui.BG)
await ui.backlight_slide(ui.BACKLIGHT_NORMAL) await ui.backlight_slide(ui.BACKLIGHT_NORMAL)

@ -15,6 +15,7 @@ import apps.wallet
import apps.ethereum import apps.ethereum
import apps.lisk import apps.lisk
import apps.nem import apps.nem
if __debug__: if __debug__:
import apps.debug import apps.debug
else: else:
@ -43,5 +44,6 @@ utils.set_mode_unprivileged()
# run main event loop and specify which screen is the default # run main event loop and specify which screen is the default
from apps.homescreen.homescreen import homescreen from apps.homescreen.homescreen import homescreen
workflow.startdefault(homescreen) workflow.startdefault(homescreen)
loop.run() loop.run()

@ -70,6 +70,7 @@ async def dump_uvarint(writer, n):
# But this is harder in Python because we don't natively know the bit size of the number. # But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative. # So we have to branch on whether the number is negative.
def sint_to_uint(sint): def sint_to_uint(sint):
res = sint << 1 res = sint << 1
if sint < 0: if sint < 0:
@ -114,11 +115,10 @@ class MessageType:
setattr(self, kw, kwargs[kw]) setattr(self, kw, kwargs[kw])
def __eq__(self, rhs): def __eq__(self, rhs):
return (self.__class__ is rhs.__class__ and return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
self.__dict__ == rhs.__dict__)
def __repr__(self): def __repr__(self):
return '<%s>' % self.__class__.__name__ return "<%s>" % self.__class__.__name__
class LimitedReader: class LimitedReader:
@ -191,7 +191,7 @@ async def load_message(reader, msg_type):
elif ftype is UnicodeType: elif ftype is UnicodeType:
fvalue = bytearray(ivalue) fvalue = bytearray(ivalue)
await reader.areadinto(fvalue) await reader.areadinto(fvalue)
fvalue = str(fvalue, 'utf8') fvalue = str(fvalue, "utf8")
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype) fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
else: else:
@ -247,7 +247,7 @@ async def dump_message(writer, msg):
await writer.awrite(svalue) await writer.awrite(svalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
bvalue = bytes(svalue, 'utf8') bvalue = bytes(svalue, "utf8")
await dump_uvarint(writer, len(bvalue)) await dump_uvarint(writer, len(bvalue))
await writer.awrite(bvalue) await writer.awrite(bvalue)

@ -2,70 +2,70 @@ from trezorcrypto import AES
def AES_ECB_Encrypt(key: bytes) -> AES: def AES_ECB_Encrypt(key: bytes) -> AES:
''' """
Create AES encryption context in ECB mode Create AES encryption context in ECB mode
''' """
return AES(AES.ECB | AES.Encrypt, key) return AES(AES.ECB | AES.Encrypt, key)
def AES_ECB_Decrypt(key: bytes) -> AES: def AES_ECB_Decrypt(key: bytes) -> AES:
''' """
Create AES decryption context in ECB mode Create AES decryption context in ECB mode
''' """
return AES(AES.ECB | AES.Decrypt, key) return AES(AES.ECB | AES.Decrypt, key)
def AES_CBC_Encrypt(key: bytes, iv: bytes) -> AES: def AES_CBC_Encrypt(key: bytes, iv: bytes) -> AES:
''' """
Create AES encryption context in CBC mode Create AES encryption context in CBC mode
''' """
return AES(AES.CBC | AES.Encrypt, key, iv) return AES(AES.CBC | AES.Encrypt, key, iv)
def AES_CBC_Decrypt(key: bytes, iv: bytes) -> AES: def AES_CBC_Decrypt(key: bytes, iv: bytes) -> AES:
''' """
Create AES decryption context in CBC mode Create AES decryption context in CBC mode
''' """
return AES(AES.CBC | AES.Decrypt, key, iv) return AES(AES.CBC | AES.Decrypt, key, iv)
def AES_CFB_Encrypt(key: bytes, iv: bytes) -> AES: def AES_CFB_Encrypt(key: bytes, iv: bytes) -> AES:
''' """
Create AES encryption context in CFB mode Create AES encryption context in CFB mode
''' """
return AES(AES.CFB | AES.Encrypt, key, iv) return AES(AES.CFB | AES.Encrypt, key, iv)
def AES_CFB_Decrypt(key: bytes, iv: bytes) -> AES: def AES_CFB_Decrypt(key: bytes, iv: bytes) -> AES:
''' """
Create AES decryption context in CFB mode Create AES decryption context in CFB mode
''' """
return AES(AES.CFB | AES.Decrypt, key, iv) return AES(AES.CFB | AES.Decrypt, key, iv)
def AES_OFB_Encrypt(key: bytes, iv: bytes) -> AES: def AES_OFB_Encrypt(key: bytes, iv: bytes) -> AES:
''' """
Create AES encryption context in OFB mode Create AES encryption context in OFB mode
''' """
return AES(AES.OFB | AES.Encrypt, key, iv) return AES(AES.OFB | AES.Encrypt, key, iv)
def AES_OFB_Decrypt(key: bytes, iv: bytes) -> AES: def AES_OFB_Decrypt(key: bytes, iv: bytes) -> AES:
''' """
Create AES decryption context in OFB mode Create AES decryption context in OFB mode
''' """
return AES(AES.OFB | AES.Decrypt, key, iv) return AES(AES.OFB | AES.Decrypt, key, iv)
def AES_CTR_Encrypt(key: bytes) -> AES: def AES_CTR_Encrypt(key: bytes) -> AES:
''' """
Create AES encryption context in CTR mode Create AES encryption context in CTR mode
''' """
return AES(AES.CTR | AES.Encrypt, key) return AES(AES.CTR | AES.Encrypt, key)
def AES_CTR_Decrypt(key: bytes) -> AES: def AES_CTR_Decrypt(key: bytes) -> AES:
''' """
Create AES decryption context in CTR mode Create AES decryption context in CTR mode
''' """
return AES(AES.CTR | AES.Decrypt, key) return AES(AES.CTR | AES.Decrypt, key)

@ -5,7 +5,7 @@
from ubinascii import unhexlify from ubinascii import unhexlify
from ustruct import unpack from ustruct import unpack
_b32alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567' _b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
_b32tab = [ord(c) for c in _b32alphabet] _b32tab = [ord(c) for c in _b32alphabet]
_b32rev = dict([(ord(v), k) for k, v in enumerate(_b32alphabet)]) _b32rev = dict([(ord(v), k) for k, v in enumerate(_b32alphabet)])
@ -24,27 +24,30 @@ def encode(s: bytes) -> str:
# leftover bit of c1 and tack it onto c2. Then we take the 2 leftover # leftover bit of c1 and tack it onto c2. Then we take the 2 leftover
# bits of c2 and tack them onto c3. The shifts and masks are intended # bits of c2 and tack them onto c3. The shifts and masks are intended
# to give us values of exactly 5 bits in width. # to give us values of exactly 5 bits in width.
c1, c2, c3 = unpack('!HHB', s[i * 5:(i + 1) * 5]) c1, c2, c3 = unpack("!HHB", s[i * 5 : (i + 1) * 5])
c2 += (c1 & 1) << 16 # 17 bits wide c2 += (c1 & 1) << 16 # 17 bits wide
c3 += (c2 & 3) << 8 # 10 bits wide c3 += (c2 & 3) << 8 # 10 bits wide
encoded += bytes([_b32tab[c1 >> 11], # bits 1 - 5 encoded += bytes(
_b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10 [
_b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15 _b32tab[c1 >> 11], # bits 1 - 5
_b32tab[c2 >> 12], # bits 16 - 20 (1 - 5) _b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
_b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10) _b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
_b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15) _b32tab[c2 >> 12], # bits 16 - 20 (1 - 5)
_b32tab[c3 >> 5], # bits 31 - 35 (1 - 5) _b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5) _b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
]) _b32tab[c3 >> 5], # bits 31 - 35 (1 - 5)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5)
]
)
# Adjust for any leftover partial quanta # Adjust for any leftover partial quanta
if leftover == 1: if leftover == 1:
encoded = encoded[:-6] + b'======' encoded = encoded[:-6] + b"======"
elif leftover == 2: elif leftover == 2:
encoded = encoded[:-4] + b'====' encoded = encoded[:-4] + b"===="
elif leftover == 3: elif leftover == 3:
encoded = encoded[:-3] + b'===' encoded = encoded[:-3] + b"==="
elif leftover == 4: elif leftover == 4:
encoded = encoded[:-1] + b'=' encoded = encoded[:-1] + b"="
return bytes(encoded).decode() return bytes(encoded).decode()
@ -53,11 +56,11 @@ def decode(s: str) -> bytes:
s = s.encode() s = s.encode()
quanta, leftover = divmod(len(s), 8) quanta, leftover = divmod(len(s), 8)
if leftover: if leftover:
raise ValueError('Incorrect padding') raise ValueError("Incorrect padding")
# Strip off pad characters from the right. We need to count the pad # Strip off pad characters from the right. We need to count the pad
# characters because this will tell us how many null bytes to remove from # characters because this will tell us how many null bytes to remove from
# the end of the decoded string. # the end of the decoded string.
padchars = s.find(b'=') padchars = s.find(b"=")
if padchars > 0: if padchars > 0:
padchars = len(s) - padchars padchars = len(s) - padchars
s = s[:-padchars] s = s[:-padchars]
@ -71,17 +74,17 @@ def decode(s: str) -> bytes:
for c in s: for c in s:
val = _b32rev.get(c) val = _b32rev.get(c)
if val is None: if val is None:
raise ValueError('Non-base32 digit found') raise ValueError("Non-base32 digit found")
acc += _b32rev[c] << shift acc += _b32rev[c] << shift
shift -= 5 shift -= 5
if shift < 0: if shift < 0:
parts.append(unhexlify(('%010x' % acc).encode())) parts.append(unhexlify(("%010x" % acc).encode()))
acc = 0 acc = 0
shift = 35 shift = 35
# Process the last, partial quanta # Process the last, partial quanta
last = unhexlify(bytes('%010x' % acc, "ascii")) last = unhexlify(bytes("%010x" % acc, "ascii"))
if padchars == 0: if padchars == 0:
last = b'' # No characters last = b"" # No characters
elif padchars == 1: elif padchars == 1:
last = last[:-1] last = last[:-1]
elif padchars == 3: elif padchars == 3:
@ -91,6 +94,6 @@ def decode(s: str) -> bytes:
elif padchars == 6: elif padchars == 6:
last = last[:-4] last = last[:-4]
else: else:
raise ValueError('Incorrect padding') raise ValueError("Incorrect padding")
parts.append(last) parts.append(last)
return b''.join(parts) return b"".join(parts)

@ -14,15 +14,15 @@
# #
# 58 character alphabet used # 58 character alphabet used
_alphabet = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' _alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
def encode(data: bytes) -> str: def encode(data: bytes) -> str:
''' """
Convert bytes to base58 encoded string. Convert bytes to base58 encoded string.
''' """
origlen = len(data) origlen = len(data)
data = data.lstrip(b'\0') data = data.lstrip(b"\0")
newlen = len(data) newlen = len(data)
p, acc = 1, 0 p, acc = 1, 0
@ -30,18 +30,18 @@ def encode(data: bytes) -> str:
acc += p * c acc += p * c
p = p << 8 p = p << 8
result = '' result = ""
while acc > 0: while acc > 0:
acc, mod = divmod(acc, 58) acc, mod = divmod(acc, 58)
result += _alphabet[mod] result += _alphabet[mod]
return ''.join((c for c in reversed(result + _alphabet[0] * (origlen - newlen)))) return "".join((c for c in reversed(result + _alphabet[0] * (origlen - newlen))))
def decode(string: str) -> bytes: def decode(string: str) -> bytes:
''' """
Convert base58 encoded string to bytes. Convert base58 encoded string to bytes.
''' """
origlen = len(string) origlen = len(string)
string = string.lstrip(_alphabet[0]) string = string.lstrip(_alphabet[0])
newlen = len(string) newlen = len(string)
@ -61,28 +61,32 @@ def decode(string: str) -> bytes:
def sha256d_32(data: bytes) -> bytes: def sha256d_32(data: bytes) -> bytes:
from .hashlib import sha256 from .hashlib import sha256
return sha256(sha256(data).digest()).digest()[:4] return sha256(sha256(data).digest()).digest()[:4]
def groestl512d_32(data: bytes) -> bytes: def groestl512d_32(data: bytes) -> bytes:
from .hashlib import groestl512 from .hashlib import groestl512
return groestl512(groestl512(data).digest()).digest()[:4] return groestl512(groestl512(data).digest()).digest()[:4]
def encode_check(data: bytes, digestfunc=sha256d_32) -> str: def encode_check(data: bytes, digestfunc=sha256d_32) -> str:
''' """
Convert bytes to base58 encoded string, append checksum. Convert bytes to base58 encoded string, append checksum.
''' """
return encode(data + digestfunc(data)) return encode(data + digestfunc(data))
def decode_check(string: str, digestfunc=sha256d_32) -> bytes: def decode_check(string: str, digestfunc=sha256d_32) -> bytes:
''' """
Convert base58 encoded string to bytes and verify checksum. Convert base58 encoded string to bytes and verify checksum.
''' """
result = decode(string) result = decode(string)
digestlen = len(digestfunc(b'')) digestlen = len(digestfunc(b""))
result, check = result[:-digestlen], result[-digestlen:] result, check = result[:-digestlen], result[-digestlen:]
if check != digestfunc(result): if check != digestfunc(result):
raise ValueError('Invalid checksum') raise ValueError("Invalid checksum")
return result return result

@ -56,22 +56,23 @@ def bech32_create_checksum(hrp, data):
def bech32_encode(hrp, data): def bech32_encode(hrp, data):
"""Compute a Bech32 string given HRP and data values.""" """Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data) combined = data + bech32_create_checksum(hrp, data)
return hrp + '1' + ''.join([CHARSET[d] for d in combined]) return hrp + "1" + "".join([CHARSET[d] for d in combined])
def bech32_decode(bech): def bech32_decode(bech):
"""Validate a Bech32 string, and determine HRP and data.""" """Validate a Bech32 string, and determine HRP and data."""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
(bech.lower() != bech and bech.upper() != bech)): bech.lower() != bech and bech.upper() != bech
):
return (None, None) return (None, None)
bech = bech.lower() bech = bech.lower()
pos = bech.rfind('1') pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None) return (None, None)
if not all(x in CHARSET for x in bech[pos + 1:]): if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None) return (None, None)
hrp = bech[:pos] hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos + 1:]] data = [CHARSET.find(x) for x in bech[pos + 1 :]]
if not bech32_verify_checksum(hrp, data): if not bech32_verify_checksum(hrp, data):
return (None, None) return (None, None)
return (hrp, data[:-6]) return (hrp, data[:-6])

@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
CHARSET = 'qpzry9x8gf2tvdw0s3jn54khce6mua7l' CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
ADDRESS_TYPE_P2KH = 0 ADDRESS_TYPE_P2KH = 0
ADDRESS_TYPE_P2SH = 8 ADDRESS_TYPE_P2SH = 8
@ -60,7 +60,7 @@ def b32decode(inputs):
def b32encode(inputs): def b32encode(inputs):
out = '' out = ""
for char_code in inputs: for char_code in inputs:
out += CHARSET[char_code] out += CHARSET[char_code]
return out return out
@ -92,13 +92,13 @@ def encode(prefix, version, payload):
payload = bytes([version]) + payload payload = bytes([version]) + payload
payload = convertbits(payload, 8, 5) payload = convertbits(payload, 8, 5)
checksum = calculate_checksum(prefix, payload) checksum = calculate_checksum(prefix, payload)
return prefix + ':' + b32encode(payload + checksum) return prefix + ":" + b32encode(payload + checksum)
def decode(prefix, addr): def decode(prefix, addr):
addr = addr.lower() addr = addr.lower()
decoded = b32decode(addr) decoded = b32decode(addr)
if not verify_checksum(prefix, decoded): if not verify_checksum(prefix, decoded):
raise ValueError('Bad cashaddr checksum') raise ValueError("Bad cashaddr checksum")
data = bytes(convertbits(decoded, 5, 8)) data = bytes(convertbits(decoded, 5, 8))
return data[0], data[1:-6] return data[0], data[1:-6]

@ -1,4 +1,3 @@
def encode_length(l: int) -> bytes: def encode_length(l: int) -> bytes:
if l < 0x80: if l < 0x80:
return bytes([l]) return bytes([l])
@ -11,14 +10,14 @@ def encode_length(l: int) -> bytes:
def encode_int(i: bytes) -> bytes: def encode_int(i: bytes) -> bytes:
i = i.lstrip(b'\x00') i = i.lstrip(b"\x00")
if i[0] >= 0x80: if i[0] >= 0x80:
i = b'\x00' + i i = b"\x00" + i
return b'\x02' + encode_length(len(i)) + i return b"\x02" + encode_length(len(i)) + i
def encode_seq(seq: tuple) -> bytes: def encode_seq(seq: tuple) -> bytes:
res = b'' res = b""
for i in seq: for i in seq:
res += encode_int(i) res += encode_int(i)
return b'\x30' + encode_length(len(res)) + res return b"\x30" + encode_length(len(res)) + res

@ -1,6 +1,4 @@
class Hmac: class Hmac:
def __init__(self, key, msg, digestmod): def __init__(self, key, msg, digestmod):
self.digestmod = digestmod self.digestmod = digestmod
self.inner = digestmod() self.inner = digestmod()
@ -15,15 +13,15 @@ class Hmac:
self.update(msg) self.update(msg)
def update(self, msg: bytes) -> None: def update(self, msg: bytes) -> None:
''' """
Update the context with data. Update the context with data.
''' """
self.inner.update(msg) self.inner.update(msg)
def digest(self) -> bytes: def digest(self) -> bytes:
''' """
Returns the digest of processed data. Returns the digest of processed data.
''' """
outer = self.digestmod() outer = self.digestmod()
outer.update(bytes((x ^ 0x5C) for x in self.key)) outer.update(bytes((x ^ 0x5C) for x in self.key))
outer.update(self.inner.digest()) outer.update(self.inner.digest())
@ -31,7 +29,7 @@ class Hmac:
def new(key, msg, digestmod) -> Hmac: def new(key, msg, digestmod) -> Hmac:
''' """
Creates a HMAC context object. Creates a HMAC context object.
''' """
return Hmac(key, msg, digestmod) return Hmac(key, msg, digestmod)

@ -1,7 +1,6 @@
def int_to_bytes(x: int) -> bytes: def int_to_bytes(x: int) -> bytes:
if x == 0: if x == 0:
return b'' return b""
r = bytearray() r = bytearray()
while x: while x:
r.append(x % 256) r.append(x % 256)
@ -17,7 +16,7 @@ def encode_length(l: int, is_list: bool) -> bytes:
bl = int_to_bytes(l) bl = int_to_bytes(l)
return bytes([len(bl) + offset + 55]) + bl return bytes([len(bl) + offset + 55]) + bl
else: else:
raise ValueError('Input too long') raise ValueError("Input too long")
def encode(data, include_length=True) -> bytes: def encode(data, include_length=True) -> bytes:
@ -31,7 +30,7 @@ def encode(data, include_length=True) -> bytes:
else: else:
return encode_length(len(data), is_list=False) + data return encode_length(len(data), is_list=False) + data
elif isinstance(data, list): elif isinstance(data, list):
output = b'' output = b""
for item in data: for item in data:
output += encode(item) output += encode(item)
if include_length: if include_length:
@ -39,7 +38,7 @@ def encode(data, include_length=True) -> bytes:
else: else:
return output return output
else: else:
raise TypeError('Invalid input of type ' + str(type(data))) raise TypeError("Invalid input of type " + str(type(data)))
def field_length(length: int, first_byte: bytearray) -> int: def field_length(length: int, first_byte: bytearray) -> int:

@ -10,11 +10,11 @@ ERROR = const(40)
CRITICAL = const(50) CRITICAL = const(50)
_leveldict = { _leveldict = {
DEBUG: ('DEBUG', '32'), DEBUG: ("DEBUG", "32"),
INFO: ('INFO', '36'), INFO: ("INFO", "36"),
WARNING: ('WARNING', '33'), WARNING: ("WARNING", "33"),
ERROR: ('ERROR', '31'), ERROR: ("ERROR", "31"),
CRITICAL: ('CRITICAL', '1;31'), CRITICAL: ("CRITICAL", "1;31"),
} }
level = DEBUG level = DEBUG
@ -24,9 +24,14 @@ color = True
def _log(name, mlevel, msg, *args): def _log(name, mlevel, msg, *args):
if __debug__ and mlevel >= level: if __debug__ and mlevel >= level:
if color: if color:
fmt = '%d \x1b[35m%s\x1b[0m \x1b[' + _leveldict[mlevel][1] + 'm%s\x1b[0m ' + msg fmt = (
"%d \x1b[35m%s\x1b[0m \x1b["
+ _leveldict[mlevel][1]
+ "m%s\x1b[0m "
+ msg
)
else: else:
fmt = '%d %s %s ' + msg fmt = "%d %s %s " + msg
print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args)) print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args))
@ -47,7 +52,7 @@ def error(name, msg, *args):
def exception(name, exc): def exception(name, exc):
_log(name, ERROR, 'exception:') _log(name, ERROR, "exception:")
sys.print_exception(exc) sys.print_exception(exc)

@ -1,11 +1,11 @@
''' """
Implements an event loop with cooperative multitasking and async I/O. Tasks in Implements an event loop with cooperative multitasking and async I/O. Tasks in
the form of python coroutines (either plain generators or `async` functions) are the form of python coroutines (either plain generators or `async` functions) are
stepped through until completion, and can get asynchronously blocked by stepped through until completion, and can get asynchronously blocked by
`yield`ing or `await`ing a syscall. `yield`ing or `await`ing a syscall.
See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `spawn`. See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `spawn`.
''' """
import utime import utime
import utimeq import utimeq
@ -22,16 +22,17 @@ _paused = {}
if __debug__: if __debug__:
# for performance stats # for performance stats
import array import array
log_delay_pos = 0 log_delay_pos = 0
log_delay_rb_len = const(10) log_delay_rb_len = const(10)
log_delay_rb = array.array('i', [0] * log_delay_rb_len) log_delay_rb = array.array("i", [0] * log_delay_rb_len)
def schedule(task, value=None, deadline=None): def schedule(task, value=None, deadline=None):
''' """
Schedule task to be executed with `value` on given `deadline` (in Schedule task to be executed with `value` on given `deadline` (in
microseconds). Does not start the event loop itself, see `run`. microseconds). Does not start the event loop itself, see `run`.
''' """
if deadline is None: if deadline is None:
deadline = utime.ticks_us() deadline = utime.ticks_us()
_queue.push(deadline, task, value) _queue.push(deadline, task, value)
@ -52,12 +53,12 @@ def close(task):
def run(): def run():
''' """
Loop forever, stepping through scheduled tasks and awaiting I/O events Loop forever, stepping through scheduled tasks and awaiting I/O events
inbetween. Use `schedule` first to add a coroutine to the task queue. inbetween. Use `schedule` first to add a coroutine to the task queue.
Tasks yield back to the scheduler on any I/O, usually by calling `await` on Tasks yield back to the scheduler on any I/O, usually by calling `await` on
a `Syscall`. a `Syscall`.
''' """
if __debug__: if __debug__:
global log_delay_pos global log_delay_pos
@ -98,7 +99,7 @@ def _step(task, value):
result = task.send(value) result = task.send(value)
except StopIteration as e: except StopIteration as e:
if __debug__: if __debug__:
log.debug(__name__, 'finish: %s', task) log.debug(__name__, "finish: %s", task)
except Exception as e: except Exception as e:
if __debug__: if __debug__:
log.exception(__name__, e) log.exception(__name__, e)
@ -109,16 +110,16 @@ def _step(task, value):
schedule(task) schedule(task)
else: else:
if __debug__: if __debug__:
log.error(__name__, 'unknown syscall: %s', result) log.error(__name__, "unknown syscall: %s", result)
if after_step_hook: if after_step_hook:
after_step_hook() after_step_hook()
class Syscall: class Syscall:
''' """
When tasks want to perform any I/O, or do any sort of communication with the When tasks want to perform any I/O, or do any sort of communication with the
scheduler, they do so through instances of a class derived from `Syscall`. scheduler, they do so through instances of a class derived from `Syscall`.
''' """
def __iter__(self): def __iter__(self):
# support `yield from` or `await` on syscalls # support `yield from` or `await` on syscalls
@ -126,7 +127,7 @@ class Syscall:
class sleep(Syscall): class sleep(Syscall):
''' """
Pause current task and resume it after given delay. Although the delay is Pause current task and resume it after given delay. Although the delay is
given in microseconds, sub-millisecond precision is not guaranteed. Result given in microseconds, sub-millisecond precision is not guaranteed. Result
value is the calculated deadline. value is the calculated deadline.
@ -135,7 +136,7 @@ class sleep(Syscall):
>>> planned = await loop.sleep(1000 * 1000) # sleep for 1ms >>> planned = await loop.sleep(1000 * 1000) # sleep for 1ms
>>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned)) >>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned))
''' """
def __init__(self, delay_us): def __init__(self, delay_us):
self.delay_us = delay_us self.delay_us = delay_us
@ -146,7 +147,7 @@ class sleep(Syscall):
class wait(Syscall): class wait(Syscall):
''' """
Pause current task, and resume only after a message on `msg_iface` is Pause current task, and resume only after a message on `msg_iface` is
received. Messages are received either from an USB interface, or the received. Messages are received either from an USB interface, or the
touch display. Result value a tuple of message values. touch display. Result value a tuple of message values.
@ -155,7 +156,7 @@ class wait(Syscall):
>>> hid_report, = await loop.wait(0xABCD) # await USB HID report >>> hid_report, = await loop.wait(0xABCD) # await USB HID report
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event >>> event, x, y = await loop.wait(io.TOUCH) # await touch event
''' """
def __init__(self, msg_iface): def __init__(self, msg_iface):
self.msg_iface = msg_iface self.msg_iface = msg_iface
@ -168,7 +169,7 @@ _NO_VALUE = ()
class signal(Syscall): class signal(Syscall):
''' """
Pause current task, and let other running task to resume it later with a Pause current task, and let other running task to resume it later with a
result value or an exception. result value or an exception.
@ -181,7 +182,7 @@ class signal(Syscall):
>>> # in task #2: >>> # in task #2:
>>> signal.send('hello from task #2') >>> signal.send('hello from task #2')
>>> # prints in the next iteration of the event loop >>> # prints in the next iteration of the event loop
''' """
def __init__(self): def __init__(self):
self.value = _NO_VALUE self.value = _NO_VALUE
@ -210,7 +211,7 @@ class signal(Syscall):
class spawn(Syscall): class spawn(Syscall):
''' """
Execute one or more children tasks and wait until one of them exits. Execute one or more children tasks and wait until one of them exits.
Return value of `spawn` is the return value of task that triggered the Return value of `spawn` is the return value of task that triggered the
completion. By default, `spawn` returns after the first child completes, and completion. By default, `spawn` returns after the first child completes, and
@ -232,7 +233,7 @@ class spawn(Syscall):
Note: You should not directly `yield` a `spawn` instance, see logic in Note: You should not directly `yield` a `spawn` instance, see logic in
`spawn.__iter__` for explanation. Always use `await`. `spawn.__iter__` for explanation. Always use `await`.
''' """
def __init__(self, *children, exit_others=True): def __init__(self, *children, exit_others=True):
self.children = children self.children = children
@ -281,7 +282,6 @@ class spawn(Syscall):
class put(Syscall): class put(Syscall):
def __init__(self, ch, value=None): def __init__(self, ch, value=None):
self.ch = ch self.ch = ch
self.value = value self.value = value
@ -295,7 +295,6 @@ class put(Syscall):
class take(Syscall): class take(Syscall):
def __init__(self, ch): def __init__(self, ch):
self.ch = ch self.ch = ch
@ -308,7 +307,6 @@ class take(Syscall):
class chan: class chan:
def __init__(self, id=None): def __init__(self, id=None):
self.id = id self.id = id
self.putters = [] self.putters = []

@ -2,18 +2,38 @@ from trezor import ui
def pin_to_int(pin: str) -> int: def pin_to_int(pin: str) -> int:
return int('1' + pin) return int("1" + pin)
def show_pin_timeout(seconds: int, progress: int): def show_pin_timeout(seconds: int, progress: int):
if progress == 0: if progress == 0:
ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT, ui.BG) ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT, ui.BG)
ui.display.text_center(ui.WIDTH // 2, 37, 'Verifying PIN', ui.BOLD, ui.FG, ui.BG, ui.WIDTH) ui.display.text_center(
ui.WIDTH // 2, 37, "Verifying PIN", ui.BOLD, ui.FG, ui.BG, ui.WIDTH
)
ui.display.loader(progress, 0, ui.FG, ui.BG) ui.display.loader(progress, 0, ui.FG, ui.BG)
if seconds == 0: if seconds == 0:
ui.display.text_center(ui.WIDTH // 2, ui.HEIGHT - 22, 'Done', ui.BOLD, ui.FG, ui.BG, ui.WIDTH) ui.display.text_center(
ui.WIDTH // 2, ui.HEIGHT - 22, "Done", ui.BOLD, ui.FG, ui.BG, ui.WIDTH
)
elif seconds == 1: elif seconds == 1:
ui.display.text_center(ui.WIDTH // 2, ui.HEIGHT - 22, '1 second left', ui.BOLD, ui.FG, ui.BG, ui.WIDTH) ui.display.text_center(
ui.WIDTH // 2,
ui.HEIGHT - 22,
"1 second left",
ui.BOLD,
ui.FG,
ui.BG,
ui.WIDTH,
)
else: else:
ui.display.text_center(ui.WIDTH // 2, ui.HEIGHT - 22, '%d seconds left' % seconds, ui.BOLD, ui.FG, ui.BG, ui.WIDTH) ui.display.text_center(
ui.WIDTH // 2,
ui.HEIGHT - 22,
"%d seconds left" % seconds,
ui.BOLD,
ui.FG,
ui.BG,
ui.WIDTH,
)
ui.display.refresh() ui.display.refresh()

@ -5,16 +5,16 @@ except ImportError:
def load(name): def load(name):
''' """
Loads resource of a given name as bytes. Loads resource of a given name as bytes.
''' """
return resdata[name] return resdata[name]
def gettext(message): def gettext(message):
''' """
Returns localized string. This function is aliased to _. Returns localized string. This function is aliased to _.
''' """
return message return message

@ -18,13 +18,12 @@ BORDER = const(4) # border size in pixels
class Button(LazyWidget): class Button(LazyWidget):
def __init__(self, area: tuple, content: str, style: dict = ui.BTN_KEY): def __init__(self, area: tuple, content: str, style: dict = ui.BTN_KEY):
self.area = area self.area = area
self.content = content self.content = content
self.normal_style = style['normal'] or ui.BTN_KEY['normal'] self.normal_style = style["normal"] or ui.BTN_KEY["normal"]
self.active_style = style['active'] or ui.BTN_KEY['active'] self.active_style = style["active"] or ui.BTN_KEY["active"]
self.disabled_style = style['disabled'] or ui.BTN_KEY['disabled'] self.disabled_style = style["disabled"] or ui.BTN_KEY["disabled"]
self.state = BTN_INITIAL self.state = BTN_INITIAL
def enable(self): def enable(self):
@ -50,28 +49,24 @@ class Button(LazyWidget):
self.render_content(s, ax, ay, aw, ah) self.render_content(s, ax, ay, aw, ah)
def render_background(self, s, ax, ay, aw, ah): def render_background(self, s, ax, ay, aw, ah):
radius = s['radius'] radius = s["radius"]
bg_color = s['bg-color'] bg_color = s["bg-color"]
border_color = s['border-color'] border_color = s["border-color"]
if border_color != bg_color: if border_color != bg_color:
# render border and background on top of it # render border and background on top of it
display.bar_radius(ax, ay, display.bar_radius(ax, ay, aw, ah, border_color, ui.BG, radius)
aw, ah, display.bar_radius(
border_color, ax + BORDER,
ui.BG, ay + BORDER,
radius) aw - BORDER * 2,
display.bar_radius(ax + BORDER, ay + BORDER, ah - BORDER * 2,
aw - BORDER * 2, ah - BORDER * 2, bg_color,
bg_color, border_color,
border_color, radius,
radius) )
else: else:
# render only the background # render only the background
display.bar_radius(ax, ay, display.bar_radius(ax, ay, aw, ah, bg_color, ui.BG, radius)
aw, ah,
bg_color,
ui.BG,
radius)
def render_content(self, s, ax, ay, aw, ah): def render_content(self, s, ax, ay, aw, ah):
c = self.content c = self.content
@ -79,10 +74,10 @@ class Button(LazyWidget):
ty = ay + ah // 2 + 8 ty = ay + ah // 2 + 8
if isinstance(c, str): if isinstance(c, str):
display.text_center( display.text_center(
tx, ty, c, s['text-style'], s['fg-color'], s['bg-color']) tx, ty, c, s["text-style"], s["fg-color"], s["bg-color"]
)
else: else:
display.icon( display.icon(tx - ICON // 2, ty - ICON, c, s["fg-color"], s["bg-color"])
tx - ICON // 2, ty - ICON, c, s['fg-color'], s['bg-color'])
def touch(self, event, pos): def touch(self, event, pos):
pos = rotate(pos) pos = rotate(pos)

@ -15,21 +15,20 @@ DEFAULT_CANCEL = res.load(ui.ICON_CANCEL)
class ConfirmDialog(Widget): class ConfirmDialog(Widget):
def __init__(self, def __init__(
content, self,
confirm=DEFAULT_CONFIRM, content,
cancel=DEFAULT_CANCEL, confirm=DEFAULT_CONFIRM,
confirm_style=ui.BTN_CONFIRM, cancel=DEFAULT_CANCEL,
cancel_style=ui.BTN_CANCEL): confirm_style=ui.BTN_CONFIRM,
cancel_style=ui.BTN_CANCEL,
):
self.content = content self.content = content
if cancel is not None: if cancel is not None:
self.confirm = Button( self.confirm = Button(ui.grid(9, n_x=2), confirm, style=confirm_style)
ui.grid(9, n_x=2), confirm, style=confirm_style) self.cancel = Button(ui.grid(8, n_x=2), cancel, style=cancel_style)
self.cancel = Button(
ui.grid(8, n_x=2), cancel, style=cancel_style)
else: else:
self.confirm = Button( self.confirm = Button(ui.grid(4, n_x=1), confirm, style=confirm_style)
ui.grid(4, n_x=1), confirm, style=confirm_style)
self.cancel = None self.cancel = None
def render(self): def render(self):
@ -56,12 +55,13 @@ _STOPPED = const(-2)
class HoldToConfirmDialog(Widget): class HoldToConfirmDialog(Widget):
def __init__(
def __init__(self, self,
content, content,
hold='Hold to confirm', hold="Hold to confirm",
button_style=ui.BTN_CONFIRM, button_style=ui.BTN_CONFIRM,
loader_style=ui.LDR_DEFAULT): loader_style=ui.LDR_DEFAULT,
):
self.content = content self.content = content
self.button = Button(ui.grid(4, n_x=1), hold, style=button_style) self.button = Button(ui.grid(4, n_x=1), hold, style=button_style)
self.loader = Loader(style=loader_style) self.loader = Loader(style=loader_style)

@ -2,7 +2,6 @@ from trezor.ui import Widget
class Container(Widget): class Container(Widget):
def __init__(self, *children): def __init__(self, *children):
self.children = children self.children = children

@ -9,11 +9,10 @@ HOST = const(1)
class EntrySelector(Widget): class EntrySelector(Widget):
def __init__(self, content): def __init__(self, content):
self.content = content self.content = content
self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), 'Device') self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device")
self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), 'Host') self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), "Host")
def render(self): def render(self):
self.device.render() self.device.render()

@ -8,11 +8,10 @@ _SHRINK_BY = const(2)
class Loader(ui.Widget): class Loader(ui.Widget):
def __init__(self, style=ui.LDR_DEFAULT): def __init__(self, style=ui.LDR_DEFAULT):
self.target_ms = _TARGET_MS self.target_ms = _TARGET_MS
self.normal_style = style['normal'] or ui.LDR_DEFAULT['normal'] self.normal_style = style["normal"] or ui.LDR_DEFAULT["normal"]
self.active_style = style['active'] or ui.LDR_DEFAULT['active'] self.active_style = style["active"] or ui.LDR_DEFAULT["active"]
self.start_ms = None self.start_ms = None
self.stop_ms = None self.stop_ms = None
@ -47,15 +46,19 @@ class Loader(ui.Widget):
s = self.active_style s = self.active_style
else: else:
s = self.normal_style s = self.normal_style
if s['icon'] is None: if s["icon"] is None:
ui.display.loader( ui.display.loader(r, -24, s["fg-color"], s["bg-color"])
r, -24, s['fg-color'], s['bg-color']) elif s["icon-fg-color"] is None:
elif s['icon-fg-color'] is None: ui.display.loader(r, -24, s["fg-color"], s["bg-color"], res.load(s["icon"]))
ui.display.loader(
r, -24, s['fg-color'], s['bg-color'], res.load(s['icon']))
else: else:
ui.display.loader( ui.display.loader(
r, -24, s['fg-color'], s['bg-color'], res.load(s['icon']), s['icon-fg-color']) r,
-24,
s["fg-color"],
s["bg-color"],
res.load(s["icon"]),
s["icon-fg-color"],
)
def __iter__(self): def __iter__(self):
sleep = loop.sleep(1000000 // 30) # 30 fps sleep = loop.sleep(1000000 // 30) # 30 fps

@ -6,7 +6,7 @@ from trezor.ui.button import BTN_CLICKED, ICON, Button
if __debug__: if __debug__:
from apps.debug import input_signal from apps.debug import input_signal
MNEMONIC_KEYS = ('abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx', 'yz') MNEMONIC_KEYS = ("abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz")
def key_buttons(keys): def key_buttons(keys):
@ -24,7 +24,7 @@ def compute_mask(text: str) -> int:
class Input(Button): class Input(Button):
def __init__(self, area: tuple, content: str='', word: str=''): def __init__(self, area: tuple, content: str = "", word: str = ""):
super().__init__(area, content) super().__init__(area, content)
self.word = word self.word = word
self.icon = None self.icon = None
@ -37,26 +37,26 @@ class Input(Button):
self.taint() self.taint()
if content == word: # confirm button if content == word: # confirm button
self.enable() self.enable()
self.normal_style = ui.BTN_KEY_CONFIRM['normal'] self.normal_style = ui.BTN_KEY_CONFIRM["normal"]
self.active_style = ui.BTN_KEY_CONFIRM['active'] self.active_style = ui.BTN_KEY_CONFIRM["active"]
self.icon = ui.ICON_CONFIRM self.icon = ui.ICON_CONFIRM
elif word: # auto-complete button elif word: # auto-complete button
self.enable() self.enable()
self.normal_style = ui.BTN_KEY['normal'] self.normal_style = ui.BTN_KEY["normal"]
self.active_style = ui.BTN_KEY['active'] self.active_style = ui.BTN_KEY["active"]
self.icon = ui.ICON_CLICK self.icon = ui.ICON_CLICK
else: # disabled button else: # disabled button
self.disable() self.disable()
self.icon = None self.icon = None
def render_content(self, s, ax, ay, aw, ah): def render_content(self, s, ax, ay, aw, ah):
text_style = s['text-style'] text_style = s["text-style"]
fg_color = s['fg-color'] fg_color = s["fg-color"]
bg_color = s['bg-color'] bg_color = s["bg-color"]
p = self.pending # should we draw the pending marker? p = self.pending # should we draw the pending marker?
t = self.content # input content t = self.content # input content
w = self.word[len(t):] # suggested word w = self.word[len(t) :] # suggested word
i = self.icon # rendered icon i = self.icon # rendered icon
tx = ax + 24 # x-offset of the content tx = ax + 24 # x-offset of the content
@ -79,12 +79,12 @@ class Input(Button):
class MnemonicKeyboard(ui.Widget): class MnemonicKeyboard(ui.Widget):
def __init__(self, prompt: str=''): def __init__(self, prompt: str = ""):
self.prompt = prompt self.prompt = prompt
self.input = Input(ui.grid(1, n_x=4, n_y=4, cells_x=3), '', '') self.input = Input(ui.grid(1, n_x=4, n_y=4, cells_x=3), "", "")
self.back = Button(ui.grid(0, n_x=4, n_y=4), self.back = Button(
res.load(ui.ICON_BACK), ui.grid(0, n_x=4, n_y=4), res.load(ui.ICON_BACK), style=ui.BTN_CLEAR
style=ui.BTN_CLEAR) )
self.keys = key_buttons(MNEMONIC_KEYS) self.keys = key_buttons(MNEMONIC_KEYS)
self.pbutton = None # pending key button self.pbutton = None # pending key button
self.pindex = 0 # index of current pending char in pbutton self.pindex = 0 # index of current pending char in pbutton
@ -114,7 +114,7 @@ class MnemonicKeyboard(ui.Widget):
if self.input.touch(event, pos) == BTN_CLICKED: if self.input.touch(event, pos) == BTN_CLICKED:
# input press, either auto-complete or confirm # input press, either auto-complete or confirm
if word and content == word: if word and content == word:
self.edit('') self.edit("")
return content return content
else: else:
self.edit(word) self.edit(word)
@ -133,7 +133,7 @@ class MnemonicKeyboard(ui.Widget):
return return
def edit(self, content, button=None, index=0): def edit(self, content, button=None, index=0):
word = bip39.find_word(content) or '' word = bip39.find_word(content) or ""
mask = bip39.complete_word(content) mask = bip39.complete_word(content)
self.pbutton = button self.pbutton = button

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save