src: run black

pull/25/head
Jan Pochyla 6 years ago
parent ead154b907
commit dcb15f77c3

@ -17,7 +17,14 @@ def addrtype_bytes(address_type: int):
if address_type <= 0xFFFFFF:
return bytes([(address_type >> 16), (address_type >> 8), (address_type & 0xFF)])
# else
return bytes([(address_type >> 24), (address_type >> 16), (address_type >> 8), (address_type & 0xFF)])
return bytes(
[
(address_type >> 24),
(address_type >> 16),
(address_type >> 8),
(address_type & 0xFF),
]
)
def check(address_type, raw_address):
@ -26,25 +33,36 @@ def check(address_type, raw_address):
if address_type <= 0xFFFF:
return address_type == (raw_address[0] << 8) | raw_address[1]
if address_type <= 0xFFFFFF:
return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
return (
address_type
== (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
)
# else
return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3]
return (
address_type
== (raw_address[0] << 24)
| (raw_address[1] << 16)
| (raw_address[2] << 8)
| raw_address[3]
)
def strip(address_type, raw_address):
if not check(address_type, raw_address):
raise ValueError('Invalid address')
raise ValueError("Invalid address")
l = length(address_type)
return raw_address[l:]
def split(coin, raw_address):
for f in ('address_type',
'address_type_p2sh',
'address_type_p2wpkh',
'address_type_p2wsh'):
for f in (
"address_type",
"address_type_p2sh",
"address_type_p2wpkh",
"address_type_p2wsh",
):
at = getattr(coin, f)
if at is not None and check(at, raw_address):
l = length(at)
return raw_address[:l], raw_address[l:]
raise ValueError('Invalid address')
raise ValueError("Invalid address")

@ -53,6 +53,6 @@ def set_passphrase(passphrase):
def clear(skip_passphrase: bool = False):
set_seed(None)
if skip_passphrase:
set_passphrase('')
set_passphrase("")
else:
set_passphrase(None)

@ -1,7 +1,7 @@
from trezor.crypto.base58 import groestl512d_32, sha256d_32
class CoinInfo:
class CoinInfo:
def __init__(
self,
coin_name: str,
@ -37,7 +37,7 @@ class CoinInfo:
self.version_group_id = version_group_id
self.bip115 = bip115
self.curve_name = curve_name
if curve_name == 'secp256k1-groestl':
if curve_name == "secp256k1-groestl":
self.b58_hash = groestl512d_32
self.sign_hash_double = False
else:

@ -19,11 +19,11 @@ def by_address_type(address_type):
for c in COINS:
if c.address_type == address_type:
return c
raise ValueError('Unknown coin address type %d' % address_type)
raise ValueError("Unknown coin address type %d" % address_type)
def by_slip44(slip44):
for c in COINS:
if c.slip44 == slip44:
return c
raise ValueError('Unknown coin slip44 index %d' % slip44)
raise ValueError("Unknown coin slip44 index %d" % slip44)

@ -21,7 +21,7 @@ async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
code = ButtonRequestType.Other
await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck)
dialog = HoldToConfirmDialog(content, 'Hold to confirm', *args, **kwargs)
dialog = HoldToConfirmDialog(content, "Hold to confirm", *args, **kwargs)
return await ctx.wait(dialog) == CONFIRMED
@ -29,10 +29,10 @@ async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
async def require_confirm(*args, **kwargs):
confirmed = await confirm(*args, **kwargs)
if not confirmed:
raise wire.ActionCancelled('Cancelled')
raise wire.ActionCancelled("Cancelled")
async def require_hold_to_confirm(*args, **kwargs):
confirmed = await hold_to_confirm(*args, **kwargs)
if not confirmed:
raise wire.ActionCancelled('Cancelled')
raise wire.ActionCancelled("Cancelled")

@ -12,14 +12,11 @@ from apps.common.confirm import confirm
async def show_address(ctx, address: str):
lines = split_address(address)
text = Text('Confirm address', ui.ICON_RECEIVE, icon_color=ui.GREEN)
text = Text("Confirm address", ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.mono(*lines)
return await confirm(
ctx,
text,
code=ButtonRequestType.Address,
cancel='QR',
cancel_style=ui.BTN_KEY)
ctx, text, code=ButtonRequestType.Address, cancel="QR", cancel_style=ui.BTN_KEY
)
async def show_qr(ctx, address: str):
@ -28,14 +25,15 @@ async def show_qr(ctx, address: str):
qr_coef = const(4)
qr = Qr(address, (qr_x, qr_y), qr_coef)
text = Text('Confirm address', ui.ICON_RECEIVE, icon_color=ui.GREEN)
text = Text("Confirm address", ui.ICON_RECEIVE, icon_color=ui.GREEN)
content = Container(qr, text)
return await confirm(
ctx,
content,
code=ButtonRequestType.Address,
cancel='Address',
cancel_style=ui.BTN_KEY)
cancel="Address",
cancel_style=ui.BTN_KEY,
)
def split_address(address: str):

@ -13,16 +13,17 @@ from apps.common.cache import get_state
@ui.layout
async def request_passphrase_entry(ctx):
text = Text('Enter passphrase', ui.ICON_CONFIG)
text.normal('Where to enter your', 'passphrase?')
text = Text("Enter passphrase", ui.ICON_CONFIG)
text.normal("Where to enter your", "passphrase?")
text.render()
ack = await ctx.call(
ButtonRequest(code=ButtonRequestType.PassphraseType),
MessageType.ButtonAck,
MessageType.Cancel)
MessageType.Cancel,
)
if ack.MESSAGE_WIRE_TYPE == MessageType.Cancel:
raise wire.ActionCancelled('Passphrase cancelled')
raise wire.ActionCancelled("Passphrase cancelled")
selector = EntrySelector(text)
return await ctx.wait(selector)
@ -31,28 +32,30 @@ async def request_passphrase_entry(ctx):
@ui.layout
async def request_passphrase_ack(ctx, on_device):
if not on_device:
text = Text('Passphrase entry', ui.ICON_CONFIG)
text.normal('Please, type passphrase', 'on connected host.')
text = Text("Passphrase entry", ui.ICON_CONFIG)
text.normal("Please, type passphrase", "on connected host.")
text.render()
req = PassphraseRequest(on_device=on_device)
ack = await ctx.call(req, MessageType.PassphraseAck, MessageType.Cancel)
if ack.MESSAGE_WIRE_TYPE == MessageType.Cancel:
raise wire.ActionCancelled('Passphrase cancelled')
raise wire.ActionCancelled("Passphrase cancelled")
if on_device:
if ack.passphrase is not None:
raise wire.ProcessError('Passphrase provided when it should not be')
keyboard = PassphraseKeyboard('Enter passphrase')
raise wire.ProcessError("Passphrase provided when it should not be")
keyboard = PassphraseKeyboard("Enter passphrase")
passphrase = await ctx.wait(keyboard)
if passphrase == CANCELLED:
raise wire.ActionCancelled('Passphrase cancelled')
raise wire.ActionCancelled("Passphrase cancelled")
else:
if ack.passphrase is None:
raise wire.ProcessError('Passphrase not provided')
raise wire.ProcessError("Passphrase not provided")
passphrase = ack.passphrase
req = PassphraseStateRequest(state=get_state(prev_state=ack.state, passphrase=passphrase))
req = PassphraseStateRequest(
state=get_state(prev_state=ack.state, passphrase=passphrase)
)
ack = await ctx.call(req, MessageType.PassphraseStateAck, MessageType.Cancel)
return passphrase
@ -71,4 +74,4 @@ async def protect_by_passphrase(ctx):
if storage.has_passphrase():
return await request_passphrase(ctx)
else:
return ''
return ""

@ -11,32 +11,31 @@ class PinCancelled(Exception):
@ui.layout
async def request_pin(label=None, cancellable: bool=True) -> str:
async def request_pin(label=None, cancellable: bool = True) -> str:
def onchange():
c = dialog.cancel
if matrix.pin:
back = res.load(ui.ICON_BACK)
if c.content is not back:
c.normal_style = ui.BTN_CLEAR['normal']
c.normal_style = ui.BTN_CLEAR["normal"]
c.content = back
c.enable()
c.taint()
else:
lock = res.load(ui.ICON_LOCK)
if not cancellable and c.content:
c.content = ''
c.content = ""
c.disable()
c.taint()
elif c.content is not lock:
c.normal_style = ui.BTN_CANCEL['normal']
c.normal_style = ui.BTN_CANCEL["normal"]
c.content = lock
c.enable()
c.taint()
c.render()
if label is None:
label = 'Enter your PIN'
label = "Enter your PIN"
matrix = PinMatrix(label)
matrix.onchange = onchange
dialog = ConfirmDialog(matrix)
@ -54,7 +53,7 @@ async def request_pin(label=None, cancellable: bool=True) -> str:
if result == CONFIRMED:
return matrix.pin
elif matrix.pin: # reset
matrix.change('')
matrix.change("")
continue
else: # cancel
raise PinCancelled()

@ -4,10 +4,12 @@ from trezor.crypto import bip32, bip39
from apps.common import cache, storage
from apps.common.request_passphrase import protect_by_passphrase
_DEFAULT_CURVE = 'secp256k1'
_DEFAULT_CURVE = "secp256k1"
async def derive_node(ctx: wire.Context, path: list, curve_name: str = _DEFAULT_CURVE) -> bip32.HDNode:
async def derive_node(
ctx: wire.Context, path: list, curve_name: str = _DEFAULT_CURVE
) -> bip32.HDNode:
seed = await _get_cached_seed(ctx)
node = bip32.from_seed(seed, curve_name)
node.derive_path(path)
@ -16,7 +18,7 @@ async def derive_node(ctx: wire.Context, path: list, curve_name: str = _DEFAULT_
async def _get_cached_seed(ctx: wire.Context) -> bytes:
if not storage.is_initialized():
raise wire.ProcessError('Device is not initialized')
raise wire.ProcessError("Device is not initialized")
if cache.get_seed() is None:
passphrase = await _get_cached_passphrase(ctx)
seed = bip39.seed(storage.get_mnemonic(), passphrase)
@ -31,11 +33,13 @@ async def _get_cached_passphrase(ctx: wire.Context) -> str:
return cache.get_passphrase()
def derive_node_without_passphrase(path: list, curve_name: str = _DEFAULT_CURVE) -> bip32.HDNode:
def derive_node_without_passphrase(
path: list, curve_name: str = _DEFAULT_CURVE
) -> bip32.HDNode:
if not storage.is_initialized():
raise Exception('Device is not initialized')
raise Exception("Device is not initialized")
seed = bip39.seed(storage.get_mnemonic(), '')
seed = bip39.seed(storage.get_mnemonic(), "")
node = bip32.from_seed(seed, curve_name)
node.derive_path(path)
return node

@ -21,7 +21,6 @@ def message_digest(coin, message):
def split_message(message):
def measure(s):
return ui.display.text_width(s, ui.NORMAL)

@ -8,7 +8,7 @@ from apps.common import cache
HOMESCREEN_MAXSIZE = 16384
_STORAGE_VERSION = b'\x01'
_STORAGE_VERSION = b"\x01"
# fmt: off
_APP = const(0x01) # app namespace
@ -64,9 +64,9 @@ def load_mnemonic(mnemonic: str, needs_backup: bool) -> None:
config.set(_APP, _MNEMONIC, mnemonic.encode())
config.set(_APP, _VERSION, _STORAGE_VERSION)
if needs_backup:
config.set(_APP, _NEEDS_BACKUP, b'\x01')
config.set(_APP, _NEEDS_BACKUP, b"\x01")
else:
config.set(_APP, _NEEDS_BACKUP, b'')
config.set(_APP, _NEEDS_BACKUP, b"")
def needs_backup() -> bool:
@ -74,7 +74,7 @@ def needs_backup() -> bool:
def set_backed_up() -> None:
config.set(_APP, _NEEDS_BACKUP, b'')
config.set(_APP, _NEEDS_BACKUP, b"")
def unfinished_backup() -> bool:
@ -83,34 +83,39 @@ def unfinished_backup() -> bool:
def set_unfinished_backup(state: bool) -> None:
if state:
config.set(_APP, _UNFINISHED_BACKUP, b'\x01')
config.set(_APP, _UNFINISHED_BACKUP, b"\x01")
else:
config.set(_APP, _UNFINISHED_BACKUP, b'')
config.set(_APP, _UNFINISHED_BACKUP, b"")
def get_passphrase_source() -> int:
b = config.get(_APP, _PASSPHRASE_SOURCE)
if b == b'\x01':
if b == b"\x01":
return 1
elif b == b'\x02':
elif b == b"\x02":
return 2
else:
return 0
def load_settings(label: str=None, use_passphrase: bool=None, homescreen: bytes=None, passphrase_source: int=None) -> None:
def load_settings(
label: str = None,
use_passphrase: bool = None,
homescreen: bytes = None,
passphrase_source: int = None,
) -> None:
if label is not None:
config.set(_APP, _LABEL, label.encode(), True) # public
if use_passphrase is True:
config.set(_APP, _USE_PASSPHRASE, b'\x01')
config.set(_APP, _USE_PASSPHRASE, b"\x01")
if use_passphrase is False:
config.set(_APP, _USE_PASSPHRASE, b'')
config.set(_APP, _USE_PASSPHRASE, b"")
if homescreen is not None:
if homescreen[:8] == b'TOIf\x90\x00\x90\x00':
if homescreen[:8] == b"TOIf\x90\x00\x90\x00":
if len(homescreen) <= HOMESCREEN_MAXSIZE:
config.set(_APP, _HOMESCREEN, homescreen, True) # public
else:
config.set(_APP, _HOMESCREEN, b'', True) # public
config.set(_APP, _HOMESCREEN, b"", True) # public
if passphrase_source is not None:
if passphrase_source in [0, 1, 2]:
config.set(_APP, _PASSPHRASE_SOURCE, bytes([passphrase_source]))
@ -121,7 +126,7 @@ def get_flags() -> int:
if b is None:
return 0
else:
return int.from_bytes(b, 'big')
return int.from_bytes(b, "big")
def set_flags(flags: int) -> None:
@ -129,10 +134,10 @@ def set_flags(flags: int) -> None:
if b is None:
b = 0
else:
b = int.from_bytes(b, 'big')
b = int.from_bytes(b, "big")
flags = (flags | b) & 0xFFFFFFFF
if flags != b:
config.set(_APP, _FLAGS, flags.to_bytes(4, 'big'))
config.set(_APP, _FLAGS, flags.to_bytes(4, "big"))
def get_autolock_delay_ms() -> int:
@ -140,13 +145,13 @@ def get_autolock_delay_ms() -> int:
if b is None:
return 10 * 60 * 1000
else:
return int.from_bytes(b, 'big')
return int.from_bytes(b, "big")
def set_autolock_delay_ms(delay_ms: int) -> None:
if delay_ms < 60 * 1000:
delay_ms = 60 * 1000
config.set(_APP, _AUTOLOCK_DELAY_MS, delay_ms.to_bytes(4, 'big'))
config.set(_APP, _AUTOLOCK_DELAY_MS, delay_ms.to_bytes(4, "big"))
def next_u2f_counter() -> int:
@ -154,13 +159,13 @@ def next_u2f_counter() -> int:
if b is None:
b = 0
else:
b = int.from_bytes(b, 'big') + 1
b = int.from_bytes(b, "big") + 1
set_u2f_counter(b)
return b
def set_u2f_counter(cntr: int):
config.set(_APP, _U2F_COUNTER, cntr.to_bytes(4, 'big'))
config.set(_APP, _U2F_COUNTER, cntr.to_bytes(4, "big"))
def wipe():

@ -1,6 +1,7 @@
if not __debug__:
from trezor.utils import halt
halt('debug mode inactive')
halt("debug mode inactive")
if __debug__:
from trezor import loop
@ -33,12 +34,16 @@ if __debug__:
m.reset_word_pos = reset_word_index
m.reset_entropy = reset_internal_entropy
if reset_current_words:
m.reset_word = ' '.join(reset_current_words)
m.reset_word = " ".join(reset_current_words)
return m
def boot():
# wipe storage when debug build is used
storage.wipe()
register(MessageType.DebugLinkDecision, protobuf_workflow, dispatch_DebugLinkDecision)
register(MessageType.DebugLinkGetState, protobuf_workflow, dispatch_DebugLinkGetState)
register(
MessageType.DebugLinkDecision, protobuf_workflow, dispatch_DebugLinkDecision
)
register(
MessageType.DebugLinkGetState, protobuf_workflow, dispatch_DebugLinkGetState
)

@ -9,21 +9,25 @@ from trezor.wire import protobuf_workflow, register
def dispatch_EthereumGetAddress(*args, **kwargs):
from .get_address import ethereum_get_address
return ethereum_get_address(*args, **kwargs)
def dispatch_EthereumSignTx(*args, **kwargs):
from .sign_tx import ethereum_sign_tx
return ethereum_sign_tx(*args, **kwargs)
def dispatch_EthereumSignMessage(*args, **kwargs):
from .sign_message import ethereum_sign_message
return ethereum_sign_message(*args, **kwargs)
def dispatch_EthereumVerifyMessage(*args, **kwargs):
from .verify_message import ethereum_verify_message
return ethereum_verify_message(*args, **kwargs)

@ -37,18 +37,18 @@ def _ethereum_address_hex(address, network=None):
hx = hexlify(address).decode()
prefix = str(network.chain_id) + '0x' if rskip60 else ''
prefix = str(network.chain_id) + "0x" if rskip60 else ""
hs = sha3_256(prefix + hx).digest(True)
h = ''
h = ""
for i in range(20):
l = hx[i * 2]
if hs[i] & 0x80 and l >= 'a' and l <= 'f':
if hs[i] & 0x80 and l >= "a" and l <= "f":
l = l.upper()
h += l
l = hx[i * 2 + 1]
if hs[i] & 0x08 and l >= 'a' and l <= 'f':
if hs[i] & 0x08 and l >= "a" and l <= "f":
l = l.upper()
h += l
return '0x' + h
return "0x" + h

@ -14,21 +14,23 @@ async def require_confirm_tx(ctx, to, value, chain_id, token=None, tx_type=None)
if to:
to_str = _ethereum_address_hex(to, networks.by_chain_id(chain_id))
else:
to_str = 'new contract?'
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.GREEN)
to_str = "new contract?"
text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_ethereum_amount(value, token, chain_id, tx_type))
text.normal('to')
text.normal("to")
text.mono(*split_address(to_str))
# we use SignTx, not ConfirmOutput, for compatibility with T1
await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_fee(ctx, spending, gas_price, gas_limit, chain_id, token=None, tx_type=None):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN)
async def require_confirm_fee(
ctx, spending, gas_price, gas_limit, chain_id, token=None, tx_type=None
):
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_ethereum_amount(spending, token, chain_id, tx_type))
text.normal('Gas price:')
text.normal("Gas price:")
text.bold(format_ethereum_amount(gas_price, None, chain_id, tx_type))
text.normal('Maximum fee:')
text.normal("Maximum fee:")
text.bold(format_ethereum_amount(gas_price * gas_limit, None, chain_id, tx_type))
await require_hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
@ -40,9 +42,9 @@ def split_data(data):
async def require_confirm_data(ctx, data, data_total):
data_str = hexlify(data[:36]).decode()
if data_total > 36:
data_str = data_str[:-2] + '..'
text = Text('Confirm data', ui.ICON_SEND, icon_color=ui.GREEN)
text.bold('Size: %d bytes' % data_total)
data_str = data_str[:-2] + ".."
text = Text("Confirm data", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold("Size: %d bytes" % data_total)
text.mono(*split_data(data_str))
# we use SignTx, not ConfirmOutput, for compatibility with T1
await require_confirm(ctx, text, ButtonRequestType.SignTx)
@ -55,7 +57,7 @@ def split_address(address):
def format_ethereum_amount(value: int, token, chain_id: int, tx_type=None):
if token:
if token is tokens.UNKNOWN_TOKEN:
return 'Unknown token value'
return "Unknown token value"
suffix = token[2]
decimals = token[3]
else:
@ -63,7 +65,7 @@ def format_ethereum_amount(value: int, token, chain_id: int, tx_type=None):
decimals = 18
if value <= 1e9:
suffix = 'Wei ' + suffix
suffix = "Wei " + suffix
decimals = 0
return '%s %s' % (format_amount(value, decimals), suffix)
return "%s %s" % (format_amount(value, decimals), suffix)

@ -1,9 +1,9 @@
def shortcut_by_chain_id(chain_id, tx_type=None):
if tx_type in [1, 6] and chain_id in [1, 3]:
return 'WAN'
return "WAN"
else:
n = by_chain_id(chain_id)
return n.shortcut if n is not None else 'UNKN'
return n.shortcut if n is not None else "UNKN"
def by_chain_id(chain_id):
@ -21,14 +21,8 @@ def by_slip44(slip44):
class NetworkInfo:
def __init__(
self,
chain_id: int,
slip44: int,
shortcut: str,
name: str,
rskip60: bool
self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool
):
self.chain_id = chain_id
self.slip44 = slip44

@ -12,7 +12,7 @@ from apps.common.signverify import split_message
def message_digest(message):
h = HashWriter(sha3_256)
signed_message_header = '\x19Ethereum Signed Message:\n'
signed_message_header = "\x19Ethereum Signed Message:\n"
h.extend(signed_message_header)
h.extend(str(len(message)))
h.extend(message)
@ -37,6 +37,6 @@ async def ethereum_sign_message(ctx, msg):
async def require_confirm_sign_message(ctx, message):
message = split_message(message)
text = Text('Sign ETH message')
text = Text("Sign ETH message")
text.normal(*message)
await require_confirm(ctx, text)

@ -26,21 +26,32 @@ async def ethereum_sign_tx(ctx, msg):
# detect ERC - 20 token
token = None
recipient = msg.to
value = int.from_bytes(msg.value, 'big')
if len(msg.to) == 20 and \
len(msg.value) == 0 and \
data_total == 68 and \
len(msg.data_initial_chunk) == 68 and \
msg.data_initial_chunk[:16] == b'\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00':
value = int.from_bytes(msg.value, "big")
if (
len(msg.to) == 20
and len(msg.value) == 0
and data_total == 68
and len(msg.data_initial_chunk) == 68
and msg.data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
):
token = tokens.token_by_chain_address(msg.chain_id, msg.to)
recipient = msg.data_initial_chunk[16:36]
value = int.from_bytes(msg.data_initial_chunk[36:68], 'big')
value = int.from_bytes(msg.data_initial_chunk[36:68], "big")
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token, msg.tx_type)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
await require_confirm_fee(ctx, value, int.from_bytes(msg.gas_price, 'big'), int.from_bytes(msg.gas_limit, 'big'), msg.chain_id, token, msg.tx_type)
await require_confirm_fee(
ctx,
value,
int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"),
msg.chain_id,
token,
msg.tx_type,
)
data = bytearray()
data += msg.data_initial_chunk
@ -97,6 +108,7 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
async def send_request_chunk(ctx, data_left: int):
from trezor.messages.MessageType import EthereumTxAck
# TODO: layoutProgress ?
req = EthereumTxRequest()
if data_left <= 1024:
@ -129,24 +141,24 @@ async def send_signature(ctx, msg: EthereumSignTx, digest):
def check(msg: EthereumSignTx):
if msg.tx_type not in [1, 6, None]:
raise wire.DataError('tx_type out of bounds')
raise wire.DataError("tx_type out of bounds")
if msg.chain_id < 0 or msg.chain_id > MAX_CHAIN_ID:
raise wire.DataError('chain_id out of bounds')
raise wire.DataError("chain_id out of bounds")
if msg.data_length > 0:
if not msg.data_initial_chunk:
raise wire.DataError('Data length provided, but no initial chunk')
raise wire.DataError("Data length provided, but no initial chunk")
# Our encoding only supports transactions up to 2^24 bytes. To
# prevent exceeding the limit we use a stricter limit on data length.
if msg.data_length > 16000000:
raise wire.DataError('Data length exceeds limit')
raise wire.DataError("Data length exceeds limit")
if len(msg.data_initial_chunk) > msg.data_length:
raise wire.DataError('Invalid size of initial chunk')
raise wire.DataError("Invalid size of initial chunk")
# safety checks
if not check_gas(msg) or not check_to(msg):
raise wire.DataError('Safety check failed')
raise wire.DataError("Safety check failed")
def check_gas(msg: EthereumSignTx) -> bool:
@ -159,7 +171,7 @@ def check_gas(msg: EthereumSignTx) -> bool:
def check_to(msg: EthereumTxRequest) -> bool:
if msg.to == b'':
if msg.to == b"":
if msg.data_length == 0:
# sending transaction to address 0 (contract creation) without a data field
return False
@ -171,15 +183,15 @@ def check_to(msg: EthereumTxRequest) -> bool:
def sanitize(msg):
if msg.value is None:
msg.value = b''
msg.value = b""
if msg.data_initial_chunk is None:
msg.data_initial_chunk = b''
msg.data_initial_chunk = b""
if msg.data_length is None:
msg.data_length = 0
if msg.to is None:
msg.to = b''
msg.to = b""
if msg.nonce is None:
msg.nonce = b''
msg.nonce = b""
if msg.chain_id is None:
msg.chain_id = 0
return msg

File diff suppressed because it is too large Load Diff

@ -18,25 +18,25 @@ async def ethereum_verify_message(ctx, msg):
pubkey = secp256k1.verify_recover(sig, digest)
if not pubkey:
raise ValueError('Invalid signature')
raise ValueError("Invalid signature")
pkh = sha3_256(pubkey[1:]).digest(True)[-20:]
if msg.address != pkh:
raise ValueError('Invalid signature')
raise ValueError("Invalid signature")
address = '0x' + hexlify(msg.address).decode()
address = "0x" + hexlify(msg.address).decode()
await require_confirm_verify_message(ctx, address, msg.message)
return Success(message='Message verified')
return Success(message="Message verified")
async def require_confirm_verify_message(ctx, address, message):
text = Text('Confirm address')
text = Text("Confirm address")
text.mono(*split_address(address))
await require_confirm(ctx, text)
text = Text('Verify message')
text = Text("Verify message")
text.mono(*split_message(message))
await require_confirm(ctx, text)

@ -18,17 +18,17 @@ _TYPE_INIT = const(0x80) # initial frame identifier
_TYPE_CONT = const(0x00) # continuation frame identifier
# types of cmd
_CMD_PING = const(0x81) # echo data through local processor only
_CMD_MSG = const(0x83) # send U2F message frame
_CMD_LOCK = const(0x84) # send lock channel command
_CMD_INIT = const(0x86) # channel initialization
_CMD_WINK = const(0x88) # send device identification wink
_CMD_PING = const(0x81) # echo data through local processor only
_CMD_MSG = const(0x83) # send U2F message frame
_CMD_LOCK = const(0x84) # send lock channel command
_CMD_INIT = const(0x86) # channel initialization
_CMD_WINK = const(0x88) # send device identification wink
_CMD_ERROR = const(0xbf) # error response
# types for the msg cmd
_MSG_REGISTER = const(0x01) # registration command
_MSG_REGISTER = const(0x01) # registration command
_MSG_AUTHENTICATE = const(0x02) # authenticate/sign command
_MSG_VERSION = const(0x03) # read version string command
_MSG_VERSION = const(0x03) # read version string command
# hid error codes
_ERR_NONE = const(0x00) # no error
@ -52,8 +52,8 @@ _SW_INS_NOT_SUPPORTED = const(0x6d00)
_SW_CLA_NOT_SUPPORTED = const(0x6e00)
# init response
_CAPFLAG_WINK = const(0x01) # device supports _CMD_WINK
_U2FHID_IF_VERSION = const(2) # interface version
_CAPFLAG_WINK = const(0x01) # device supports _CMD_WINK
_U2FHID_IF_VERSION = const(2) # interface version
# register response
_U2F_KEY_PATH = const(0x80553246)
@ -63,18 +63,18 @@ _U2F_ATT_CERT = b"0\x82\x01\x180\x81\xc0\x02\t\x00\xb1\xd9\x8fBdr\xd3,0\n\x06\x0
_BOGUS_APPID = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
# authentication control byte
_AUTH_ENFORCE = const(0x03) # enforce user presence and sign
_AUTH_ENFORCE = const(0x03) # enforce user presence and sign
_AUTH_CHECK_ONLY = const(0x07) # check only
_AUTH_FLAG_TUP = const(0x01) # test of user presence set
_AUTH_FLAG_TUP = const(0x01) # test of user presence set
# common raw message format (ISO7816-4:2005 mapping)
_APDU_CLA = const(0) # uint8_t cla; // Class - reserved
_APDU_INS = const(1) # uint8_t ins; // U2F instruction
_APDU_P1 = const(2) # uint8_t p1; // U2F parameter 1
_APDU_P2 = const(3) # uint8_t p2; // U2F parameter 2
_APDU_LC1 = const(4) # uint8_t lc1; // Length field, set to zero
_APDU_LC2 = const(5) # uint8_t lc2; // Length field, MSB
_APDU_LC3 = const(6) # uint8_t lc3; // Length field, LSB
_APDU_CLA = const(0) # uint8_t cla; // Class - reserved
_APDU_INS = const(1) # uint8_t ins; // U2F instruction
_APDU_P1 = const(2) # uint8_t p1; // U2F parameter 1
_APDU_P2 = const(3) # uint8_t p2; // U2F parameter 2
_APDU_LC1 = const(4) # uint8_t lc1; // Length field, set to zero
_APDU_LC2 = const(5) # uint8_t lc2; // Length field, MSB
_APDU_LC3 = const(6) # uint8_t lc3; // Length field, LSB
_APDU_DATA = const(7) # uint8_t data[1]; // Data field
@ -85,10 +85,10 @@ def frame_init() -> dict:
# uint8_t bcntl; // Message byte count - low part
# uint8_t data[HID_RPT_SIZE - 7]; // Data payload
return {
'cid': 0 | uctypes.UINT32,
'cmd': 4 | uctypes.UINT8,
'bcnt': 5 | uctypes.UINT16,
'data': (7 | uctypes.ARRAY, (_HID_RPT_SIZE - 7) | uctypes.UINT8),
"cid": 0 | uctypes.UINT32,
"cmd": 4 | uctypes.UINT8,
"bcnt": 5 | uctypes.UINT16,
"data": (7 | uctypes.ARRAY, (_HID_RPT_SIZE - 7) | uctypes.UINT8),
}
@ -97,9 +97,9 @@ def frame_cont() -> dict:
# uint8_t seq; // Sequence number - b7 cleared
# uint8_t data[HID_RPT_SIZE - 5]; // Data payload
return {
'cid': 0 | uctypes.UINT32,
'seq': 4 | uctypes.UINT8,
'data': (5 | uctypes.ARRAY, (_HID_RPT_SIZE - 5) | uctypes.UINT8),
"cid": 0 | uctypes.UINT32,
"seq": 4 | uctypes.UINT8,
"data": (5 | uctypes.ARRAY, (_HID_RPT_SIZE - 5) | uctypes.UINT8),
}
@ -112,13 +112,13 @@ def resp_cmd_init() -> dict:
# uint8_t versionBuild; // Build version number
# uint8_t capFlags; // Capabilities flags
return {
'nonce': (0 | uctypes.ARRAY, 8 | uctypes.UINT8),
'cid': 8 | uctypes.UINT32,
'versionInterface': 12 | uctypes.UINT8,
'versionMajor': 13 | uctypes.UINT8,
'versionMinor': 14 | uctypes.UINT8,
'versionBuild': 15 | uctypes.UINT8,
'capFlags': 16 | uctypes.UINT8,
"nonce": (0 | uctypes.ARRAY, 8 | uctypes.UINT8),
"cid": 8 | uctypes.UINT32,
"versionInterface": 12 | uctypes.UINT8,
"versionMajor": 13 | uctypes.UINT8,
"versionMinor": 14 | uctypes.UINT8,
"versionBuild": 15 | uctypes.UINT8,
"capFlags": 16 | uctypes.UINT8,
}
@ -134,13 +134,13 @@ def resp_cmd_register(khlen: int, certlen: int, siglen: int) -> dict:
# uint8_t sig[siglen]; // Registration signature
# uint16_t status;
return {
'registerId': 0 | uctypes.UINT8,
'pubKey': (1 | uctypes.ARRAY, 65 | uctypes.UINT8),
'keyHandleLen': 66 | uctypes.UINT8,
'keyHandle': (67 | uctypes.ARRAY, khlen | uctypes.UINT8),
'cert': (cert_ofs | uctypes.ARRAY, certlen | uctypes.UINT8),
'sig': (sig_ofs | uctypes.ARRAY, siglen | uctypes.UINT8),
'status': status_ofs | uctypes.UINT16,
"registerId": 0 | uctypes.UINT8,
"pubKey": (1 | uctypes.ARRAY, 65 | uctypes.UINT8),
"keyHandleLen": 66 | uctypes.UINT8,
"keyHandle": (67 | uctypes.ARRAY, khlen | uctypes.UINT8),
"cert": (cert_ofs | uctypes.ARRAY, certlen | uctypes.UINT8),
"sig": (sig_ofs | uctypes.ARRAY, siglen | uctypes.UINT8),
"status": status_ofs | uctypes.UINT16,
}
@ -154,10 +154,10 @@ def req_cmd_authenticate(khlen: int) -> dict:
# uint8_t keyHandleLen; // Length of key handle
# uint8_t keyHandle[khlen]; // Key handle
return {
'chal': (0 | uctypes.ARRAY, 32 | uctypes.UINT8),
'appId': (32 | uctypes.ARRAY, 32 | uctypes.UINT8),
'keyHandleLen': 64 | uctypes.UINT8,
'keyHandle': (65 | uctypes.ARRAY, khlen | uctypes.UINT8),
"chal": (0 | uctypes.ARRAY, 32 | uctypes.UINT8),
"appId": (32 | uctypes.ARRAY, 32 | uctypes.UINT8),
"keyHandleLen": 64 | uctypes.UINT8,
"keyHandle": (65 | uctypes.ARRAY, khlen | uctypes.UINT8),
}
@ -168,17 +168,17 @@ def resp_cmd_authenticate(siglen: int) -> dict:
# uint8_t sig[siglen]; // Signature
# uint16_t status;
return {
'flags': 0 | uctypes.UINT8,
'ctr': 1 | uctypes.UINT32,
'sig': (5 | uctypes.ARRAY, siglen | uctypes.UINT8),
'status': status_ofs | uctypes.UINT16,
"flags": 0 | uctypes.UINT8,
"ctr": 1 | uctypes.UINT32,
"sig": (5 | uctypes.ARRAY, siglen | uctypes.UINT8),
"status": status_ofs | uctypes.UINT16,
}
def overlay_struct(buf, desc):
desc_size = uctypes.sizeof(desc, uctypes.BIG_ENDIAN)
if desc_size > len(buf):
raise ValueError('desc is too big (%d > %d)' % (desc_size, len(buf)))
raise ValueError("desc is too big (%d > %d)" % (desc_size, len(buf)))
return uctypes.struct(uctypes.addressof(buf), desc, uctypes.BIG_ENDIAN)
@ -189,8 +189,9 @@ def make_struct(desc):
class Msg:
def __init__(self, cid: int, cla: int, ins: int, p1: int, p2: int, lc: int, data: bytes) -> None:
def __init__(
self, cid: int, cla: int, ins: int, p1: int, p2: int, lc: int, data: bytes
) -> None:
self.cid = cid
self.cla = cla
self.ins = ins
@ -201,7 +202,6 @@ class Msg:
class Cmd:
def __init__(self, cid: int, cmd: int, data: bytes) -> None:
self.cid = cid
self.cmd = cmd
@ -212,10 +212,12 @@ class Cmd:
ins = self.data[_APDU_INS]
p1 = self.data[_APDU_P1]
p2 = self.data[_APDU_P2]
lc = (self.data[_APDU_LC1] << 16) + \
(self.data[_APDU_LC2] << 8) + \
(self.data[_APDU_LC3])
data = self.data[_APDU_DATA:_APDU_DATA + lc]
lc = (
(self.data[_APDU_LC1] << 16)
+ (self.data[_APDU_LC2] << 8)
+ (self.data[_APDU_LC3])
)
data = self.data[_APDU_DATA : _APDU_DATA + lc]
return Msg(self.cid, cla, ins, p1, p2, lc, data)
@ -235,7 +237,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
if ifrm.cmd & _TYPE_MASK == _TYPE_CONT:
# unexpected cont packet, abort current msg
if __debug__:
log.warning(__name__, '_TYPE_CONT')
log.warning(__name__, "_TYPE_CONT")
return None
if datalen < bcnt:
@ -253,13 +255,13 @@ async def read_cmd(iface: io.HID) -> Cmd:
if cfrm.seq == _CMD_INIT:
# _CMD_INIT frame, cancels current channel
ifrm = overlay_struct(buf, desc_init)
data = ifrm.data[:ifrm.bcnt]
data = ifrm.data[: ifrm.bcnt]
break
if cfrm.cid != ifrm.cid:
# cont frame for a different channel, reply with BUSY and skip
if __debug__:
log.warning(__name__, '_ERR_CHANNEL_BUSY')
log.warning(__name__, "_ERR_CHANNEL_BUSY")
await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface)
continue
@ -267,7 +269,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
# cont frame for this channel, but incorrect seq number, abort
# current msg
if __debug__:
log.warning(__name__, '_ERR_INVALID_SEQ')
log.warning(__name__, "_ERR_INVALID_SEQ")
await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface)
return None
@ -330,7 +332,6 @@ _CONFIRM_TIMEOUT_MS = const(10 * 1000)
class ConfirmState:
def __init__(self) -> None:
self.reset()
@ -382,19 +383,20 @@ class ConfirmState:
from trezor.ui.text import Text
if bytes(self.app_id) == _BOGUS_APPID:
text = Text('U2F mismatch', ui.ICON_WRONG, icon_color=ui.RED)
text.normal('Another U2F device', 'was used to register', 'in this application.')
text = Text("U2F mismatch", ui.ICON_WRONG, icon_color=ui.RED)
text.normal(
"Another U2F device", "was used to register", "in this application."
)
text.render()
await loop.sleep(3 * 1000 * 1000)
self.confirmed = True
else:
content = ConfirmContent(self.action, self.app_id)
dialog = ConfirmDialog(content, )
dialog = ConfirmDialog(content)
self.confirmed = await dialog == CONFIRMED
class ConfirmContent(ui.Widget):
def __init__(self, action: int, app_id: bytes) -> None:
self.action = action
self.app_id = app_id
@ -411,25 +413,30 @@ class ConfirmContent(ui.Widget):
if app_id == _BOGUS_APPID:
# TODO: display a warning dialog for bogus app ids
name = 'Another U2F device'
icon = res.load('apps/fido_u2f/res/u2f_generic.toif') # TODO: warning icon
name = "Another U2F device"
icon = res.load("apps/fido_u2f/res/u2f_generic.toif") # TODO: warning icon
elif app_id in knownapps.knownapps:
name = knownapps.knownapps[app_id]
try:
icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_'))
icon = res.load(
"apps/fido_u2f/res/u2f_%s.toif" % name.lower().replace(" ", "_")
)
except Exception:
icon = res.load('apps/fido_u2f/res/u2f_generic.toif')
icon = res.load("apps/fido_u2f/res/u2f_generic.toif")
else:
name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode())
icon = res.load('apps/fido_u2f/res/u2f_generic.toif')
name = "%s...%s" % (
hexlify(app_id[:4]).decode(),
hexlify(app_id[-4:]).decode(),
)
icon = res.load("apps/fido_u2f/res/u2f_generic.toif")
self.app_name = name
self.app_icon = icon
def render(self) -> None:
if self.action == _CONFIRM_REGISTER:
header = 'U2F Register'
header = "U2F Register"
else:
header = 'U2F Authenticate'
header = "U2F Authenticate"
ui.header(header, ui.ICON_DEFAULT, ui.GREEN, ui.BG, ui.GREEN)
ui.display.image((ui.WIDTH - 64) // 2, 64, self.app_icon)
ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG)
@ -441,46 +448,46 @@ def dispatch_cmd(req: Cmd, state: ConfirmState) -> Cmd:
if m.cla != 0:
if __debug__:
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED')
log.warning(__name__, "_SW_CLA_NOT_SUPPORTED")
return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED)
if m.lc + _APDU_DATA > len(req.data):
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH')
log.warning(__name__, "_SW_WRONG_LENGTH")
return msg_error(req.cid, _SW_WRONG_LENGTH)
if m.ins == _MSG_REGISTER:
if __debug__:
log.debug(__name__, '_MSG_REGISTER')
log.debug(__name__, "_MSG_REGISTER")
return msg_register(m, state)
elif m.ins == _MSG_AUTHENTICATE:
if __debug__:
log.debug(__name__, '_MSG_AUTHENTICATE')
log.debug(__name__, "_MSG_AUTHENTICATE")
return msg_authenticate(m, state)
elif m.ins == _MSG_VERSION:
if __debug__:
log.debug(__name__, '_MSG_VERSION')
log.debug(__name__, "_MSG_VERSION")
return msg_version(m)
else:
if __debug__:
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins)
log.warning(__name__, "_SW_INS_NOT_SUPPORTED: %d", m.ins)
return msg_error(req.cid, _SW_INS_NOT_SUPPORTED)
elif req.cmd == _CMD_INIT:
if __debug__:
log.debug(__name__, '_CMD_INIT')
log.debug(__name__, "_CMD_INIT")
return cmd_init(req)
elif req.cmd == _CMD_PING:
if __debug__:
log.debug(__name__, '_CMD_PING')
log.debug(__name__, "_CMD_PING")
return req
elif req.cmd == _CMD_WINK:
if __debug__:
log.debug(__name__, '_CMD_WINK')
log.debug(__name__, "_CMD_WINK")
return req
else:
if __debug__:
log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd)
log.warning(__name__, "_ERR_INVALID_CMD: %d", req.cmd)
return cmd_error(req.cid, _ERR_INVALID_CMD)
@ -510,13 +517,13 @@ def msg_register(req: Msg, state: ConfirmState) -> Cmd:
if not storage.is_initialized():
if __debug__:
log.warning(__name__, 'not initialized')
log.warning(__name__, "not initialized")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# check length of input data
if len(req.data) != 64:
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
log.warning(__name__, "_SW_WRONG_LENGTH req.data")
return msg_error(req.cid, _SW_WRONG_LENGTH)
# parse challenge and app_id
@ -532,12 +539,12 @@ def msg_register(req: Msg, state: ConfirmState) -> Cmd:
# wait for a button or continue
if not state.confirmed:
if __debug__:
log.info(__name__, 'waiting for button')
log.info(__name__, "waiting for button")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# sign the registration challenge and return
if __debug__:
log.info(__name__, 'signing register')
log.info(__name__, "signing register")
buf = msg_register_sign(chal, app_id)
state.reset()
@ -554,11 +561,11 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
nodepath = [_U2F_KEY_PATH] + keypath
# prepare signing key from random path, compute decompressed public key
node = seed.derive_node_without_passphrase(nodepath, 'nist256p1')
node = seed.derive_node_without_passphrase(nodepath, "nist256p1")
pubkey = nist256p1.publickey(node.private_key(), False)
# first half of keyhandle is keypath
keybuf = ustruct.pack('>8L', *keypath)
keybuf = ustruct.pack(">8L", *keypath)
# second half of keyhandle is a hmac of app_id and keypath
keybase = hmac.Hmac(node.private_key(), app_id, hashlib.sha256)
@ -567,12 +574,12 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
# hash the request data together with keyhandle and pubkey
dig = hashlib.sha256()
dig.update(b'\x00') # uint8_t reserved;
dig.update(app_id) # uint8_t appId[32];
dig.update(b"\x00") # uint8_t reserved;
dig.update(app_id) # uint8_t appId[32];
dig.update(challenge) # uint8_t chal[32];
dig.update(keybuf) # uint8_t keyHandle[64];
dig.update(keybuf) # uint8_t keyHandle[64];
dig.update(keybase)
dig.update(pubkey) # uint8_t pubKey[65];
dig.update(pubkey) # uint8_t pubKey[65];
dig = dig.digest()
# sign the digest and convert to der
@ -580,8 +587,9 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
sig = der.encode_seq((sig[1:33], sig[33:]))
# pack to a response
buf, resp = make_struct(resp_cmd_register(
len(keybuf) + len(keybase), len(_U2F_ATT_CERT), len(sig)))
buf, resp = make_struct(
resp_cmd_register(len(keybuf) + len(keybase), len(_U2F_ATT_CERT), len(sig))
)
resp.registerId = _U2F_REGISTER_ID
utils.memcpy(resp.pubKey, 0, pubkey, 0, len(pubkey))
resp.keyHandleLen = len(keybuf) + len(keybase)
@ -599,20 +607,20 @@ def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
if not storage.is_initialized():
if __debug__:
log.warning(__name__, 'not initialized')
log.warning(__name__, "not initialized")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# we need at least keyHandleLen
if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN:
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
log.warning(__name__, "_SW_WRONG_LENGTH req.data")
return msg_error(req.cid, _SW_WRONG_LENGTH)
# check keyHandleLen
khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN]
if khlen != 64:
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH khlen')
log.warning(__name__, "_SW_WRONG_LENGTH khlen")
return msg_error(req.cid, _SW_WRONG_LENGTH)
auth = overlay_struct(req.data, req_cmd_authenticate(khlen))
@ -626,13 +634,13 @@ def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
# if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already
if req.p1 == _AUTH_CHECK_ONLY:
if __debug__:
log.info(__name__, '_AUTH_CHECK_ONLY')
log.info(__name__, "_AUTH_CHECK_ONLY")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# from now on, only _AUTH_ENFORCE is supported
if req.p1 != _AUTH_ENFORCE:
if __debug__:
log.info(__name__, '_AUTH_ENFORCE')
log.info(__name__, "_AUTH_ENFORCE")
return msg_error(req.cid, _SW_WRONG_DATA)
# check equality with last request
@ -644,12 +652,12 @@ def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
# wait for a button or continue
if not state.confirmed:
if __debug__:
log.info(__name__, 'waiting for button')
log.info(__name__, "waiting for button")
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# sign the authentication challenge and return
if __debug__:
log.info(__name__, 'signing authentication')
log.info(__name__, "signing authentication")
buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key())
state.reset()
@ -662,18 +670,18 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# unpack the keypath from the first half of keyhandle
keybuf = keyhandle[:32]
keypath = ustruct.unpack('>8L', keybuf)
keypath = ustruct.unpack(">8L", keybuf)
# check high bit for hardened keys
for i in keypath:
if not i & HARDENED:
if __debug__:
log.warning(__name__, 'invalid key path')
log.warning(__name__, "invalid key path")
return None
# derive the signing key
nodepath = [_U2F_KEY_PATH] + list(keypath)
node = seed.derive_node_without_passphrase(nodepath, 'nist256p1')
node = seed.derive_node_without_passphrase(nodepath, "nist256p1")
# second half of keyhandle is a hmac of app_id and keypath
keybase = hmac.Hmac(node.private_key(), app_id, hashlib.sha256)
@ -683,7 +691,7 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# verify the hmac
if keybase != keyhandle[32:]:
if __debug__:
log.warning(__name__, 'invalid key handle')
log.warning(__name__, "invalid key handle")
return None
return node
@ -694,13 +702,13 @@ def msg_authenticate_sign(challenge: bytes, app_id: bytes, privkey: bytes) -> by
# get next counter
ctr = storage.next_u2f_counter()
ctrbuf = ustruct.pack('>L', ctr)
ctrbuf = ustruct.pack(">L", ctr)
# hash input data together with counter
dig = hashlib.sha256()
dig.update(app_id) # uint8_t appId[32];
dig.update(flags) # uint8_t flags;
dig.update(ctrbuf) # uint8_t ctr[4];
dig.update(app_id) # uint8_t appId[32];
dig.update(flags) # uint8_t flags;
dig.update(ctrbuf) # uint8_t ctr[4];
dig.update(challenge) # uint8_t chal[32];
dig = dig.digest()
@ -721,12 +729,12 @@ def msg_authenticate_sign(challenge: bytes, app_id: bytes, privkey: bytes) -> by
def msg_version(req: Msg) -> Cmd:
if req.data:
return msg_error(req.cid, _SW_WRONG_LENGTH)
return Cmd(req.cid, _CMD_MSG, b'U2F_V2\x90\x00') # includes _SW_NO_ERROR
return Cmd(req.cid, _CMD_MSG, b"U2F_V2\x90\x00") # includes _SW_NO_ERROR
def msg_error(cid: int, code: int) -> Cmd:
return Cmd(cid, _CMD_MSG, ustruct.pack('>H', code))
return Cmd(cid, _CMD_MSG, ustruct.pack(">H", code))
def cmd_error(cid: int, code: int) -> Cmd:
return Cmd(cid, _CMD_ERROR, ustruct.pack('>B', code))
return Cmd(cid, _CMD_ERROR, ustruct.pack(">B", code))

@ -1,19 +1,25 @@
from trezor.crypto import hashlib
knownapps = {
hashlib.sha256(b'https://account.gandi.net/api/u2f/trusted_facets.json').digest(): 'Gandi',
hashlib.sha256(b'https://api-9dcf9b83.duosecurity.com').digest(): 'Duo',
hashlib.sha256(b'https://bitbucket.org').digest(): 'Bitbucket',
hashlib.sha256(b'https://dashboard.stripe.com').digest(): 'Stripe',
hashlib.sha256(b'https://demo.yubico.com').digest(): 'Yubico U2F Demo',
hashlib.sha256(b'https://github.com/u2f/trusted_facets').digest(): 'GitHub',
hashlib.sha256(b'https://gitlab.com').digest(): 'GitLab',
hashlib.sha256(b'https://keepersecurity.com').digest(): 'Keeper',
hashlib.sha256(b'https://slushpool.com/static/security/u2f.json').digest(): 'Slush Pool',
hashlib.sha256(b'https://u2f.bin.coffee').digest(): 'u2f.bin.coffee checker',
hashlib.sha256(b'https://vault.bitwarden.com/app-id.json').digest(): 'bitwarden',
hashlib.sha256(b'https://www.bitfinex.com').digest(): 'Bitfinex',
hashlib.sha256(b'https://www.dropbox.com/u2f-app-id.json').digest(): 'Dropbox',
hashlib.sha256(b'https://www.fastmail.com').digest(): 'FastMail',
hashlib.sha256(b'https://www.gstatic.com/securitykey/origins.json').digest(): 'Google',
hashlib.sha256(
b"https://account.gandi.net/api/u2f/trusted_facets.json"
).digest(): "Gandi",
hashlib.sha256(b"https://api-9dcf9b83.duosecurity.com").digest(): "Duo",
hashlib.sha256(b"https://bitbucket.org").digest(): "Bitbucket",
hashlib.sha256(b"https://dashboard.stripe.com").digest(): "Stripe",
hashlib.sha256(b"https://demo.yubico.com").digest(): "Yubico U2F Demo",
hashlib.sha256(b"https://github.com/u2f/trusted_facets").digest(): "GitHub",
hashlib.sha256(b"https://gitlab.com").digest(): "GitLab",
hashlib.sha256(b"https://keepersecurity.com").digest(): "Keeper",
hashlib.sha256(
b"https://slushpool.com/static/security/u2f.json"
).digest(): "Slush Pool",
hashlib.sha256(b"https://u2f.bin.coffee").digest(): "u2f.bin.coffee checker",
hashlib.sha256(b"https://vault.bitwarden.com/app-id.json").digest(): "bitwarden",
hashlib.sha256(b"https://www.bitfinex.com").digest(): "Bitfinex",
hashlib.sha256(b"https://www.dropbox.com/u2f-app-id.json").digest(): "Dropbox",
hashlib.sha256(b"https://www.fastmail.com").digest(): "FastMail",
hashlib.sha256(
b"https://www.gstatic.com/securitykey/origins.json"
).digest(): "Google",
}

@ -9,15 +9,15 @@ from apps.common import cache, storage
def get_features():
f = Features()
f.vendor = 'trezor.io'
f.language = 'english'
f.major_version = utils.symbol('VERSION_MAJOR')
f.minor_version = utils.symbol('VERSION_MINOR')
f.patch_version = utils.symbol('VERSION_PATCH')
f.revision = utils.symbol('GITREV')
f.vendor = "trezor.io"
f.language = "english"
f.major_version = utils.symbol("VERSION_MAJOR")
f.minor_version = utils.symbol("VERSION_MINOR")
f.patch_version = utils.symbol("VERSION_PATCH")
f.revision = utils.symbol("GITREV")
f.model = utils.model()
if f.model == 'EMU':
f.model = 'T' # emulator currently emulates model T
if f.model == "EMU":
f.model = "T" # emulator currently emulates model T
f.device_id = storage.get_device_id()
f.label = storage.get_label()
f.initialized = storage.is_initialized()
@ -42,12 +42,12 @@ async def handle_GetFeatures(ctx, msg):
async def handle_Cancel(ctx, msg):
raise wire.ActionCancelled('Cancelled')
raise wire.ActionCancelled("Cancelled")
async def handle_ClearSession(ctx, msg):
cache.clear()
return Success(message='Session cleared')
return Success(message="Session cleared")
async def handle_Ping(ctx, msg):
@ -55,9 +55,11 @@ async def handle_Ping(ctx, msg):
from apps.common.confirm import require_confirm
from trezor.messages.ButtonRequestType import ProtectCall
from trezor.ui.text import Text
await require_confirm(ctx, Text('Confirm'), ProtectCall)
await require_confirm(ctx, Text("Confirm"), ProtectCall)
if msg.passphrase_protection:
from apps.common.request_passphrase import protect_by_passphrase
await protect_by_passphrase(ctx)
return Success(message=msg.message)

@ -14,26 +14,32 @@ async def homescreen():
def display_homescreen():
if not storage.is_initialized():
label = 'Go to trezor.io/start'
label = "Go to trezor.io/start"
image = None
else:
label = storage.get_label() or 'My TREZOR'
label = storage.get_label() or "My TREZOR"
image = storage.get_homescreen()
if not image:
image = res.load('apps/homescreen/res/bg.toif')
image = res.load("apps/homescreen/res/bg.toif")
if storage.is_initialized() and storage.unfinished_backup():
ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED)
ui.display.text_center(ui.WIDTH // 2, 22, 'BACKUP FAILED!', ui.BOLD, ui.WHITE, ui.RED)
ui.display.text_center(
ui.WIDTH // 2, 22, "BACKUP FAILED!", ui.BOLD, ui.WHITE, ui.RED
)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
elif storage.is_initialized() and storage.needs_backup():
ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW)
ui.display.text_center(ui.WIDTH // 2, 22, 'NEEDS BACKUP!', ui.BOLD, ui.BLACK, ui.YELLOW)
ui.display.text_center(
ui.WIDTH // 2, 22, "NEEDS BACKUP!", ui.BOLD, ui.BLACK, ui.YELLOW
)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
elif storage.is_initialized() and not config.has_pin():
ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW)
ui.display.text_center(ui.WIDTH // 2, 22, 'PIN NOT SET!', ui.BOLD, ui.BLACK, ui.YELLOW)
ui.display.text_center(
ui.WIDTH // 2, 22, "PIN NOT SET!", ui.BOLD, ui.BLACK, ui.YELLOW
)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
else:
ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT, ui.BG)

@ -10,26 +10,31 @@ from trezor.wire import protobuf_workflow, register
def dispatch_LiskGetAddress(*args, **kwargs):
from .get_address import layout_lisk_get_address
return layout_lisk_get_address(*args, **kwargs)
def dispatch_LiskGetPublicKey(*args, **kwargs):
from .get_public_key import lisk_get_public_key
return lisk_get_public_key(*args, **kwargs)
def dispatch_LiskSignTx(*args, **kwargs):
from .sign_tx import lisk_sign_tx
return lisk_sign_tx(*args, **kwargs)
def dispatch_LiskSignMessage(*args, **kwargs):
from .sign_message import lisk_sign_message
return lisk_sign_message(*args, **kwargs)
def dispatch_LiskVerifyMessage(*args, **kwargs):
from .verify_message import lisk_verify_message
return lisk_verify_message(*args, **kwargs)

@ -1,18 +1,18 @@
from trezor.crypto.hashlib import sha256
LISK_CURVE = 'ed25519'
LISK_CURVE = "ed25519"
def get_address_from_public_key(pubkey):
pubkeyhash = sha256(pubkey).digest()
address = int.from_bytes(pubkeyhash[:8], 'little')
return str(address) + 'L'
address = int.from_bytes(pubkeyhash[:8], "little")
return str(address) + "L"
def get_votes_count(votes):
plus, minus = 0, 0
for vote in votes:
if vote.startswith('+'):
if vote.startswith("+"):
plus += 1
else:
minus += 1
@ -23,11 +23,11 @@ def get_vote_tx_text(votes):
plus, minus = get_votes_count(votes)
text = []
if plus > 0:
text.append(_text_with_plural('Add', plus))
text.append(_text_with_plural("Add", plus))
if minus > 0:
text.append(_text_with_plural('Remove', minus))
text.append(_text_with_plural("Remove", minus))
return text
def _text_with_plural(txt, value):
return '%s %s %s' % (txt, value, ('votes' if value != 1 else 'vote'))
return "%s %s %s" % (txt, value, ("votes" if value != 1 else "vote"))

@ -10,23 +10,23 @@ from apps.wallet.get_public_key import _show_pubkey
async def require_confirm_tx(ctx, to, value):
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.GREEN)
text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_amount(value))
text.normal('to')
text.normal("to")
text.mono(*split_address(to))
return await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_delegate_registration(ctx, delegate_name):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Do you really want to')
text.normal('register a delegate?')
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal("Do you really want to")
text.normal("register a delegate?")
text.bold(*chunks(delegate_name, 20))
return await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_vote_tx(ctx, votes):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN)
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(*get_vote_tx_text(votes))
return await require_confirm(ctx, text, ButtonRequestType.SignTx)
@ -36,23 +36,23 @@ async def require_confirm_public_key(ctx, public_key):
async def require_confirm_multisig(ctx, multisignature):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Keys group length: %s' % len(multisignature.keys_group))
text.normal('Life time: %s' % multisignature.life_time)
text.normal('Min: %s' % multisignature.min)
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal("Keys group length: %s" % len(multisignature.keys_group))
text.normal("Life time: %s" % multisignature.life_time)
text.normal("Min: %s" % multisignature.min)
return await require_confirm(ctx, text, ButtonRequestType.SignTx)
async def require_confirm_fee(ctx, value, fee):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN)
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold(format_amount(value))
text.normal('fee:')
text.normal("fee:")
text.bold(format_amount(fee))
await require_hold_to_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
def format_amount(value):
return '%s LSK' % (int(value) / 100000000)
return "%s LSK" % (int(value) / 100000000)
def split_address(address):

@ -14,7 +14,7 @@ from apps.wallet.sign_tx.signing import write_varint
def message_digest(message):
h = HashWriter(sha256)
signed_message_header = 'Lisk Signed Message:\n'
signed_message_header = "Lisk Signed Message:\n"
write_varint(h, len(signed_message_header))
h.extend(signed_message_header)
write_varint(h, len(message))
@ -39,6 +39,6 @@ async def lisk_sign_message(ctx, msg):
async def require_confirm_sign_message(ctx, message):
text = Text('Sign Lisk message')
text = Text("Sign Lisk message")
text.normal(*split_message(message))
await require_confirm(ctx, text)

@ -20,7 +20,7 @@ async def lisk_sign_tx(ctx, msg):
try:
await _require_confirm_by_type(ctx, transaction)
except AttributeError:
raise wire.DataError('The transaction has invalid asset data field')
raise wire.DataError("The transaction has invalid asset data field")
await layout.require_confirm_fee(ctx, transaction.amount, transaction.fee)
@ -61,48 +61,59 @@ async def _require_confirm_by_type(ctx, transaction):
if transaction.type == LiskTransactionType.Transfer:
return await layout.require_confirm_tx(
ctx, transaction.recipient_id, transaction.amount)
ctx, transaction.recipient_id, transaction.amount
)
if transaction.type == LiskTransactionType.RegisterDelegate:
return await layout.require_confirm_delegate_registration(
ctx, transaction.asset.delegate.username)
ctx, transaction.asset.delegate.username
)
if transaction.type == LiskTransactionType.CastVotes:
return await layout.require_confirm_vote_tx(
ctx, transaction.asset.votes)
return await layout.require_confirm_vote_tx(ctx, transaction.asset.votes)
if transaction.type == LiskTransactionType.RegisterSecondPassphrase:
return await layout.require_confirm_public_key(
ctx, transaction.asset.signature.public_key)
ctx, transaction.asset.signature.public_key
)
if transaction.type == LiskTransactionType.RegisterMultisignatureAccount:
return await layout.require_confirm_multisig(
ctx, transaction.asset.multisignature)
ctx, transaction.asset.multisignature
)
raise wire.DataError('Invalid transaction type')
raise wire.DataError("Invalid transaction type")
def _get_transaction_bytes(tx):
# Required transaction parameters
t_type = ustruct.pack('<b', tx.type)
t_timestamp = ustruct.pack('<i', tx.timestamp)
t_type = ustruct.pack("<b", tx.type)
t_timestamp = ustruct.pack("<i", tx.timestamp)
t_sender_public_key = tx.sender_public_key
t_requester_public_key = tx.requester_public_key or b''
t_requester_public_key = tx.requester_public_key or b""
if not tx.recipient_id:
# Value can be empty string
t_recipient_id = ustruct.pack('>Q', 0)
t_recipient_id = ustruct.pack(">Q", 0)
else:
# Lisk uses big-endian for recipient_id, string -> int -> bytes
t_recipient_id = ustruct.pack('>Q', int(tx.recipient_id[:-1]))
t_recipient_id = ustruct.pack(">Q", int(tx.recipient_id[:-1]))
t_amount = ustruct.pack('<Q', tx.amount)
t_amount = ustruct.pack("<Q", tx.amount)
t_asset = _get_asset_data_bytes(tx)
t_signature = tx.signature or b''
t_signature = tx.signature or b""
return (t_type, t_timestamp, t_sender_public_key, t_requester_public_key,
t_recipient_id, t_amount, t_asset, t_signature)
return (
t_type,
t_timestamp,
t_sender_public_key,
t_requester_public_key,
t_recipient_id,
t_amount,
t_asset,
t_signature,
)
def _get_asset_data_bytes(msg):
@ -110,24 +121,24 @@ def _get_asset_data_bytes(msg):
if msg.type == LiskTransactionType.Transfer:
# Transfer transaction have optional data field
if msg.asset.data is not None:
return bytes(msg.asset.data, 'utf8')
return bytes(msg.asset.data, "utf8")
else:
return b''
return b""
if msg.type == LiskTransactionType.RegisterDelegate:
return bytes(msg.asset.delegate.username, 'utf8')
return bytes(msg.asset.delegate.username, "utf8")
if msg.type == LiskTransactionType.CastVotes:
return bytes(''.join(msg.asset.votes), 'utf8')
return bytes("".join(msg.asset.votes), "utf8")
if msg.type == LiskTransactionType.RegisterSecondPassphrase:
return msg.asset.signature.public_key
if msg.type == LiskTransactionType.RegisterMultisignatureAccount:
data = b''
data += ustruct.pack('<b', msg.asset.multisignature.min)
data += ustruct.pack('<b', msg.asset.multisignature.life_time)
data += bytes(''.join(msg.asset.multisignature.keys_group), 'utf8')
data = b""
data += ustruct.pack("<b", msg.asset.multisignature.min)
data += ustruct.pack("<b", msg.asset.multisignature.life_time)
data += bytes("".join(msg.asset.multisignature.keys_group), "utf8")
return data
raise wire.DataError('Invalid transaction type')
raise wire.DataError("Invalid transaction type")

@ -12,9 +12,9 @@ async def lisk_verify_message(ctx, msg):
digest = message_digest(msg.message)
verified = ed25519.verify(msg.public_key, msg.signature, digest)
if not verified:
raise wire.ProcessError('Invalid signature')
raise wire.ProcessError("Invalid signature")
address = get_address_from_public_key(msg.public_key)
await require_confirm_verify_message(ctx, address, msg.message)
return Success(message='Message verified')
return Success(message="Message verified")

@ -14,46 +14,55 @@ from trezor.wire import protobuf_workflow, register
def dispatch_LoadDevice(*args, **kwargs):
from .load_device import load_device
return load_device(*args, **kwargs)
def dispatch_ResetDevice(*args, **kwargs):
from .reset_device import reset_device
return reset_device(*args, **kwargs)
def dispatch_BackupDevice(*args, **kwargs):
from .backup_device import backup_device
return backup_device(*args, **kwargs)
def dispatch_WipeDevice(*args, **kwargs):
from .wipe_device import wipe_device
return wipe_device(*args, **kwargs)
def dispatch_RecoveryDevice(*args, **kwargs):
from .recovery_device import recovery_device
return recovery_device(*args, **kwargs)
def dispatch_ApplySettings(*args, **kwargs):
from .apply_settings import apply_settings
return apply_settings(*args, **kwargs)
def dispatch_ApplyFlags(*args, **kwargs):
from .apply_flags import apply_flags
return apply_flags(*args, **kwargs)
def dispatch_ChangePin(*args, **kwargs):
from .change_pin import change_pin
return change_pin(*args, **kwargs)
def dispatch_SetU2FCounter(*args, **kwargs):
from .set_u2f_counter import set_u2f_counter
return set_u2f_counter(*args, **kwargs)

@ -5,4 +5,4 @@ from apps.common import storage
async def apply_flags(ctx, msg):
storage.set_flags(msg.flags)
return Success(message='Flags applied')
return Success(message="Flags applied")

@ -8,12 +8,17 @@ from apps.common.confirm import require_confirm
async def apply_settings(ctx, msg):
if msg.homescreen is None and msg.label is None and msg.use_passphrase is None and msg.passphrase_source is None:
raise wire.ProcessError('No setting provided')
if (
msg.homescreen is None
and msg.label is None
and msg.use_passphrase is None
and msg.passphrase_source is None
):
raise wire.ProcessError("No setting provided")
if msg.homescreen is not None:
if len(msg.homescreen) > storage.HOMESCREEN_MAXSIZE:
raise wire.DataError('Homescreen is too complex')
raise wire.DataError("Homescreen is too complex")
await require_confirm_change_homescreen(ctx)
if msg.label is not None:
@ -25,43 +30,45 @@ async def apply_settings(ctx, msg):
if msg.passphrase_source is not None:
await require_confirm_change_passphrase_source(ctx, msg.passphrase_source)
storage.load_settings(label=msg.label,
use_passphrase=msg.use_passphrase,
homescreen=msg.homescreen,
passphrase_source=msg.passphrase_source)
storage.load_settings(
label=msg.label,
use_passphrase=msg.use_passphrase,
homescreen=msg.homescreen,
passphrase_source=msg.passphrase_source,
)
return Success(message='Settings applied')
return Success(message="Settings applied")
async def require_confirm_change_homescreen(ctx):
text = Text('Change homescreen', ui.ICON_CONFIG)
text.normal('Do you really want to', 'change homescreen?')
text = Text("Change homescreen", ui.ICON_CONFIG)
text.normal("Do you really want to", "change homescreen?")
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def require_confirm_change_label(ctx, label):
text = Text('Change label', ui.ICON_CONFIG)
text.normal('Do you really want to', 'change label to')
text.bold('%s?' % label)
text = Text("Change label", ui.ICON_CONFIG)
text.normal("Do you really want to", "change label to")
text.bold("%s?" % label)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def require_confirm_change_passphrase(ctx, use):
text = Text('Enable passphrase' if use else 'Disable passphrase', ui.ICON_CONFIG)
text.normal('Do you really want to')
text.normal('enable passphrase' if use else 'disable passphrase')
text.normal('encryption?')
text = Text("Enable passphrase" if use else "Disable passphrase", ui.ICON_CONFIG)
text.normal("Do you really want to")
text.normal("enable passphrase" if use else "disable passphrase")
text.normal("encryption?")
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def require_confirm_change_passphrase_source(ctx, source):
if source == PassphraseSourceType.DEVICE:
desc = 'ON DEVICE'
desc = "ON DEVICE"
elif source == PassphraseSourceType.HOST:
desc = 'ON HOST'
desc = "ON HOST"
else:
desc = 'ASK'
text = Text('Passphrase source', ui.ICON_CONFIG)
text.normal('Do you really want to', 'change the passphrase', 'source to')
text.bold('ALWAYS %s?' % desc)
desc = "ASK"
text = Text("Passphrase source", ui.ICON_CONFIG)
text.normal("Do you really want to", "change the passphrase", "source to")
text.bold("ALWAYS %s?" % desc)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)

@ -12,9 +12,9 @@ from apps.management.reset_device import (
async def backup_device(ctx, msg):
if not storage.is_initialized():
raise wire.ProcessError('Device is not initialized')
raise wire.ProcessError("Device is not initialized")
if not storage.needs_backup():
raise wire.ProcessError('Seed already backed up')
raise wire.ProcessError("Seed already backed up")
mnemonic = storage.get_mnemonic()
@ -33,4 +33,4 @@ async def backup_device(ctx, msg):
storage.set_unfinished_backup(False)
return Success(message='Seed successfully backed up')
return Success(message="Seed successfully backed up")

@ -18,52 +18,52 @@ async def change_pin(ctx, msg):
if config.has_pin():
curpin = await request_pin_ack(ctx)
if not config.check_pin(pin_to_int(curpin), show_pin_timeout):
raise wire.PinInvalid('PIN invalid')
raise wire.PinInvalid("PIN invalid")
else:
curpin = ''
curpin = ""
# get new pin
if not msg.remove:
newpin = await request_pin_confirm(ctx)
else:
newpin = ''
newpin = ""
# write into storage
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), show_pin_timeout):
raise wire.PinInvalid('PIN invalid')
raise wire.PinInvalid("PIN invalid")
if newpin:
return Success(message='PIN changed')
return Success(message="PIN changed")
else:
return Success(message='PIN removed')
return Success(message="PIN removed")
def require_confirm_change_pin(ctx, msg):
has_pin = config.has_pin()
if msg.remove and has_pin: # removing pin
text = Text('Remove PIN', ui.ICON_CONFIG)
text.normal('Do you really want to')
text.bold('remove current PIN?')
text = Text("Remove PIN", ui.ICON_CONFIG)
text.normal("Do you really want to")
text.bold("remove current PIN?")
return require_confirm(ctx, text)
if not msg.remove and has_pin: # changing pin
text = Text('Remove PIN', ui.ICON_CONFIG)
text.normal('Do you really want to')
text.bold('change current PIN?')
text = Text("Remove PIN", ui.ICON_CONFIG)
text.normal("Do you really want to")
text.bold("change current PIN?")
return require_confirm(ctx, text)
if not msg.remove and not has_pin: # setting new pin
text = Text('Remove PIN', ui.ICON_CONFIG)
text.normal('Do you really want to')
text.bold('set new PIN?')
text = Text("Remove PIN", ui.ICON_CONFIG)
text.normal("Do you really want to")
text.bold("set new PIN?")
return require_confirm(ctx, text)
async def request_pin_confirm(ctx, *args, **kwargs):
while True:
pin1 = await request_pin_ack(ctx, 'Enter new PIN', *args, **kwargs)
pin2 = await request_pin_ack(ctx, 'Re-enter new PIN', *args, **kwargs)
pin1 = await request_pin_ack(ctx, "Enter new PIN", *args, **kwargs)
pin2 = await request_pin_ack(ctx, "Re-enter new PIN", *args, **kwargs)
if pin1 == pin2:
return pin1
await pin_mismatch()
@ -71,17 +71,19 @@ async def request_pin_confirm(ctx, *args, **kwargs):
async def request_pin_ack(ctx, *args, **kwargs):
try:
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck)
await ctx.call(
ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck
)
return await ctx.wait(request_pin(*args, **kwargs))
except PinCancelled:
raise wire.ActionCancelled('Cancelled')
raise wire.ActionCancelled("Cancelled")
@ui.layout
async def pin_mismatch():
text = Text('PIN mismatch', ui.ICON_WRONG, icon_color=ui.RED)
text.normal('Entered PINs do not', 'match each other.')
text.normal('')
text.normal('Please, try again...')
text = Text("PIN mismatch", ui.ICON_WRONG, icon_color=ui.RED)
text.normal("Entered PINs do not", "match each other.")
text.normal("")
text.normal("Please, try again...")
text.render()
await loop.sleep(3 * 1000 * 1000)

@ -11,24 +11,22 @@ from apps.common.confirm import require_confirm
async def load_device(ctx, msg):
if storage.is_initialized():
raise wire.UnexpectedMessage('Already initialized')
raise wire.UnexpectedMessage("Already initialized")
if msg.node is not None:
raise wire.ProcessError('LoadDevice.node is not supported')
raise wire.ProcessError("LoadDevice.node is not supported")
if not msg.skip_checksum and not bip39.check(msg.mnemonic):
raise wire.ProcessError('Mnemonic is not valid')
raise wire.ProcessError("Mnemonic is not valid")
text = Text('Loading seed')
text.bold('Loading private seed', 'is not recommended.')
text.normal('Continue only if you', 'know what you are doing!')
text = Text("Loading seed")
text.bold("Loading private seed", "is not recommended.")
text.normal("Continue only if you", "know what you are doing!")
await require_confirm(ctx, text)
storage.load_mnemonic(
mnemonic=msg.mnemonic, needs_backup=True)
storage.load_settings(
use_passphrase=msg.passphrase_protection, label=msg.label)
storage.load_mnemonic(mnemonic=msg.mnemonic, needs_backup=True)
storage.load_settings(use_passphrase=msg.passphrase_protection, label=msg.label)
if msg.pin:
config.change_pin(pin_to_int(''), pin_to_int(msg.pin), None)
config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None)
return Success(message='Device loaded')
return Success(message="Device loaded")

@ -15,7 +15,7 @@ from apps.management.change_pin import request_pin_confirm
async def recovery_device(ctx, msg):
'''
"""
Recover BIP39 seed into empty device.
1. Ask for the number of words in recovered seed.
@ -23,9 +23,9 @@ async def recovery_device(ctx, msg):
3. Optionally check the seed validity.
4. Optionally ask for the PIN, with confirmation.
5. Save into storage.
'''
"""
if not msg.dry_run and storage.is_initialized():
raise wire.UnexpectedMessage('Already initialized')
raise wire.UnexpectedMessage("Already initialized")
# ask for the number of words
wordcount = await request_wordcount(ctx)
@ -36,7 +36,7 @@ async def recovery_device(ctx, msg):
# check mnemonic validity
if msg.enforce_wordlist or msg.dry_run:
if not bip39.check(mnemonic):
raise wire.ProcessError('Mnemonic is not valid')
raise wire.ProcessError("Mnemonic is not valid")
# ask for pin repeatedly
if msg.pin_protection:
@ -45,25 +45,27 @@ async def recovery_device(ctx, msg):
# save into storage
if not msg.dry_run:
if msg.pin_protection:
config.change_pin(pin_to_int(''), pin_to_int(newpin), None)
storage.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection)
storage.load_mnemonic(
mnemonic=mnemonic, needs_backup=False)
return Success(message='Device recovered')
config.change_pin(pin_to_int(""), pin_to_int(newpin), None)
storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
storage.load_mnemonic(mnemonic=mnemonic, needs_backup=False)
return Success(message="Device recovered")
else:
if storage.get_mnemonic() == mnemonic:
return Success(message='The seed is valid and matches the one in the device')
return Success(
message="The seed is valid and matches the one in the device"
)
else:
raise wire.ProcessError('The seed is valid but does not match the one in the device')
raise wire.ProcessError(
"The seed is valid but does not match the one in the device"
)
@ui.layout
async def request_wordcount(ctx):
await ctx.call(ButtonRequest(code=MnemonicWordCount), ButtonAck)
text = Text('Device recovery', ui.ICON_RECOVERY)
text.normal('Number of words?')
text = Text("Device recovery", ui.ICON_RECOVERY)
text.normal("Number of words?")
count = await ctx.wait(WordSelector(text))
return count
@ -76,8 +78,8 @@ async def request_mnemonic(ctx, count: int) -> str:
words = []
board = MnemonicKeyboard()
for i in range(count):
board.prompt = 'Type the %s word:' % format_ordinal(i + 1)
board.prompt = "Type the %s word:" % format_ordinal(i + 1)
word = await ctx.wait(board)
words.append(word)
return ' '.join(words)
return " ".join(words)

@ -25,15 +25,15 @@ if __debug__:
async def reset_device(ctx, msg):
# validate parameters and device state
if msg.strength not in (128, 192, 256):
raise wire.ProcessError('Invalid strength (has to be 128, 192 or 256 bits)')
raise wire.ProcessError("Invalid strength (has to be 128, 192 or 256 bits)")
if storage.is_initialized():
raise wire.UnexpectedMessage('Already initialized')
raise wire.UnexpectedMessage("Already initialized")
# request new PIN
if msg.pin_protection:
newpin = await request_pin_confirm(ctx)
else:
newpin = ''
newpin = ""
# generate and display internal entropy
internal_ent = random.bytes(32)
@ -58,14 +58,12 @@ async def reset_device(ctx, msg):
await show_wrong_entry(ctx)
# write PIN into storage
if not config.change_pin(pin_to_int(''), pin_to_int(newpin), None):
raise wire.ProcessError('Could not change PIN')
if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None):
raise wire.ProcessError("Could not change PIN")
# write settings and mnemonic into storage
storage.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection)
storage.load_mnemonic(
mnemonic=mnemonic, needs_backup=msg.skip_backup)
storage.load_settings(label=msg.label, use_passphrase=msg.passphrase_protection)
storage.load_mnemonic(mnemonic=mnemonic, needs_backup=msg.skip_backup)
# show success message. if we skipped backup, it's possible that homescreen
# is still running, uninterrupted. restart it to pick up new label.
@ -74,78 +72,64 @@ async def reset_device(ctx, msg):
else:
workflow.restartdefault()
return Success(message='Initialized')
return Success(message="Initialized")
def generate_mnemonic(strength: int,
int_entropy: bytes,
ext_entropy: bytes) -> bytes:
def generate_mnemonic(strength: int, int_entropy: bytes, ext_entropy: bytes) -> bytes:
ehash = hashlib.sha256()
ehash.update(int_entropy)
ehash.update(ext_entropy)
entropy = ehash.digest()
mnemonic = bip39.from_data(entropy[:strength // 8])
mnemonic = bip39.from_data(entropy[: strength // 8])
return mnemonic
async def show_warning(ctx):
text = Text('Backup your seed', ui.ICON_NOCOPY)
text = Text("Backup your seed", ui.ICON_NOCOPY)
text.normal(
'Never make a digital',
'copy of your recovery',
'seed and never upload',
'it online!')
"Never make a digital",
"copy of your recovery",
"seed and never upload",
"it online!",
)
await require_confirm(
ctx,
text,
ButtonRequestType.ResetDevice,
confirm='I understand',
cancel=None)
ctx, text, ButtonRequestType.ResetDevice, confirm="I understand", cancel=None
)
async def show_wrong_entry(ctx):
text = Text('Wrong entry!', ui.ICON_WRONG, icon_color=ui.RED)
text.normal(
'You have entered',
'wrong seed word.',
'Please check again.')
text = Text("Wrong entry!", ui.ICON_WRONG, icon_color=ui.RED)
text.normal("You have entered", "wrong seed word.", "Please check again.")
await require_confirm(
ctx,
text,
ButtonRequestType.ResetDevice,
confirm='Check again',
cancel=None)
ctx, text, ButtonRequestType.ResetDevice, confirm="Check again", cancel=None
)
async def show_success(ctx):
text = Text('Backup is done!', ui.ICON_CONFIRM, icon_color=ui.GREEN)
text = Text("Backup is done!", ui.ICON_CONFIRM, icon_color=ui.GREEN)
text.normal(
'Never make a digital',
'copy of your recovery',
'seed and never upload',
'it online!')
"Never make a digital",
"copy of your recovery",
"seed and never upload",
"it online!",
)
await require_confirm(
ctx,
text,
ButtonRequestType.ResetDevice,
confirm='Finish setup',
cancel=None)
ctx, text, ButtonRequestType.ResetDevice, confirm="Finish setup", cancel=None
)
async def show_entropy(ctx, entropy: bytes):
entropy_str = hexlify(entropy).decode()
lines = chunks(entropy_str, 16)
text = Text('Internal entropy', ui.ICON_RESET)
text = Text("Internal entropy", ui.ICON_RESET)
text.mono(*lines)
await require_confirm(
ctx,
text,
ButtonRequestType.ResetDevice)
await require_confirm(ctx, text, ButtonRequestType.ResetDevice)
async def show_mnemonic(ctx, mnemonic: str):
await ctx.call(
ButtonRequest(code=ButtonRequestType.ResetDevice), MessageType.ButtonAck)
ButtonRequest(code=ButtonRequestType.ResetDevice), MessageType.ButtonAck
)
first_page = const(0)
words_per_page = const(4)
words = list(enumerate(mnemonic.split()))
@ -159,8 +143,8 @@ async def show_mnemonic_page(page: int, page_count: int, pages: list):
if __debug__:
debug.reset_current_words = [word for _, word in pages[page]]
lines = ['%2d. %s' % (wi + 1, word) for wi, word in pages[page]]
text = Text('Recovery seed', ui.ICON_RESET)
lines = ["%2d. %s" % (wi + 1, word) for wi, word in pages[page]]
text = Text("Recovery seed", ui.ICON_RESET)
text.mono(*lines)
content = Scrollpage(text, page, page_count)
@ -192,6 +176,6 @@ async def check_word(ctx, words: list, index: int):
if __debug__:
debug.reset_word_index = index
keyboard = MnemonicKeyboard('Type the %s word:' % format_ordinal(index + 1))
keyboard = MnemonicKeyboard("Type the %s word:" % format_ordinal(index + 1))
result = await ctx.wait(keyboard)
return result == words[index]

@ -9,13 +9,13 @@ from apps.common.confirm import require_confirm
async def set_u2f_counter(ctx, msg):
if msg.u2f_counter is None:
raise wire.ProcessError('No value provided')
raise wire.ProcessError("No value provided")
text = Text('Set U2F counter', ui.ICON_CONFIG)
text.normal('Do you really want to', 'set the U2F counter')
text.bold('to %d?' % msg.u2f_counter)
text = Text("Set U2F counter", ui.ICON_CONFIG)
text.normal("Do you really want to", "set the U2F counter")
text.bold("to %d?" % msg.u2f_counter)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
storage.set_u2f_counter(msg.u2f_counter)
return Success(message='U2F counter set')
return Success(message="U2F counter set")

@ -9,15 +9,18 @@ from apps.common.confirm import require_hold_to_confirm
async def wipe_device(ctx, msg):
text = Text('Wipe device', ui.ICON_WIPE, icon_color=ui.RED)
text.normal('Do you really want to', 'wipe the device?', '')
text.bold('All data will be lost.')
text = Text("Wipe device", ui.ICON_WIPE, icon_color=ui.RED)
text.normal("Do you really want to", "wipe the device?", "")
text.bold("All data will be lost.")
await require_hold_to_confirm(ctx, text,
await require_hold_to_confirm(
ctx,
text,
code=ButtonRequestType.WipeDevice,
button_style=ui.BTN_CANCEL,
loader_style=ui.LDR_DANGER)
loader_style=ui.LDR_DANGER,
)
storage.wipe()
return Success(message='Device wiped')
return Success(message="Device wiped")

@ -4,11 +4,13 @@ from trezor.wire import protobuf_workflow, register
def dispatch_NemGetAddress(*args, **kwargs):
from .get_address import get_address
return get_address(*args, **kwargs)
def dispatch_NemSignTx(*args, **kwargs):
from .signing import sign_tx
return sign_tx(*args, **kwargs)

@ -30,7 +30,9 @@ async def get_address(ctx, msg):
async def _show_address(ctx, address: str, network: int):
lines = split_address(address)
text = Text('Confirm address', ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.normal('%s network' % get_network_str(network))
text = Text("Confirm address", ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.normal("%s network" % get_network_str(network))
text.mono(*lines)
return await confirm(ctx, text, code=ButtonRequestType.Address, cancel='QR', cancel_style=ui.BTN_KEY)
return await confirm(
ctx, text, code=ButtonRequestType.Address, cancel="QR", cancel_style=ui.BTN_KEY
)

@ -3,7 +3,7 @@ from micropython import const
NEM_NETWORK_MAINNET = const(0x68)
NEM_NETWORK_TESTNET = const(0x98)
NEM_NETWORK_MIJIN = const(0x60)
NEM_CURVE = 'ed25519-keccak'
NEM_CURVE = "ed25519-keccak"
NEM_TRANSACTION_TYPE_TRANSFER = const(0x0101)
NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER = const(0x0801)
@ -19,7 +19,7 @@ NEM_MAX_SUPPLY = const(9000000000)
NEM_SALT_SIZE = const(32)
AES_BLOCK_SIZE = const(16)
NEM_HASH_ALG = 'keccak'
NEM_HASH_ALG = "keccak"
NEM_PUBLIC_KEY_SIZE = const(32) # ed25519 public key
NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE = const(10000)
NEM_MOSAIC_AMOUNT_DIVISOR = const(1000000)
@ -30,8 +30,8 @@ NEM_MAX_ENCRYPTED_PAYLOAD_SIZE = const(960)
def get_network_str(network: int) -> str:
if network == NEM_NETWORK_MAINNET:
return 'Mainnet'
return "Mainnet"
elif network == NEM_NETWORK_TESTNET:
return 'Testnet'
return "Testnet"
elif network == NEM_NETWORK_MIJIN:
return 'Mijin'
return "Mijin"

@ -10,15 +10,17 @@ from apps.common.confirm import require_confirm, require_hold_to_confirm
async def require_confirm_text(ctx, action: str):
words = split_words(action, 18)
await require_confirm_content(ctx, 'Confirm action', words)
await require_confirm_content(ctx, "Confirm action", words)
async def require_confirm_fee(ctx, action: str, fee: int):
content = (
ui.NORMAL, action,
ui.BOLD, '%s XEM' % format_amount(fee, NEM_MAX_DIVISIBILITY),
ui.NORMAL,
action,
ui.BOLD,
"%s XEM" % format_amount(fee, NEM_MAX_DIVISIBILITY),
)
await require_confirm_content(ctx, 'Confirm fee', content)
await require_confirm_content(ctx, "Confirm fee", content)
async def require_confirm_content(ctx, headline: str, content: list):
@ -28,10 +30,10 @@ async def require_confirm_content(ctx, headline: str, content: list):
async def require_confirm_final(ctx, fee: int):
text = Text('Final confirm', ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Sign this transaction')
text.bold('and pay %s XEM' % format_amount(fee, NEM_MAX_DIVISIBILITY))
text.normal('for network fee?')
text = Text("Final confirm", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal("Sign this transaction")
text.bold("and pay %s XEM" % format_amount(fee, NEM_MAX_DIVISIBILITY))
text.normal("for network fee?")
# we use SignTx, not ConfirmOutput, for compatibility with T1
await require_hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
@ -42,5 +44,5 @@ def split_address(address: str):
def trim(payload: str, length: int) -> str:
if len(payload) > length:
return payload[:length] + '..'
return payload[:length] + ".."
return payload

@ -5,11 +5,15 @@ from trezor.messages.NEMTransactionCommon import NEMTransactionCommon
from . import layout, serialize
async def mosaic_creation(ctx, public_key: bytes, common: NEMTransactionCommon, creation: NEMMosaicCreation) -> bytearray:
async def mosaic_creation(
ctx, public_key: bytes, common: NEMTransactionCommon, creation: NEMMosaicCreation
) -> bytearray:
await layout.ask_mosaic_creation(ctx, common, creation)
return serialize.serialize_mosaic_creation(common, creation, public_key)
async def supply_change(ctx, public_key: bytes, common: NEMTransactionCommon, change: NEMMosaicSupplyChange) -> bytearray:
async def supply_change(
ctx, public_key: bytes, common: NEMTransactionCommon, change: NEMMosaicSupplyChange
) -> bytearray:
await layout.ask_supply_change(ctx, common, change)
return serialize.serialize_mosaic_supply_change(common, change, public_key)

@ -24,39 +24,55 @@ from ..layout import (
)
async def ask_mosaic_creation(ctx, common: NEMTransactionCommon, creation: NEMMosaicCreation):
await require_confirm_content(ctx, 'Create mosaic', _creation_message(creation))
async def ask_mosaic_creation(
ctx, common: NEMTransactionCommon, creation: NEMMosaicCreation
):
await require_confirm_content(ctx, "Create mosaic", _creation_message(creation))
await _require_confirm_properties(ctx, creation.definition)
await require_confirm_fee(ctx, 'Confirm creation fee', creation.fee)
await require_confirm_fee(ctx, "Confirm creation fee", creation.fee)
await require_confirm_final(ctx, common.fee)
async def ask_supply_change(ctx, common: NEMTransactionCommon, change: NEMMosaicSupplyChange):
await require_confirm_content(ctx, 'Supply change', _supply_message(change))
async def ask_supply_change(
ctx, common: NEMTransactionCommon, change: NEMMosaicSupplyChange
):
await require_confirm_content(ctx, "Supply change", _supply_message(change))
if change.type == NEMSupplyChangeType.SupplyChange_Decrease:
msg = 'Decrease supply by ' + str(change.delta) + ' whole units?'
msg = "Decrease supply by " + str(change.delta) + " whole units?"
elif change.type == NEMSupplyChangeType.SupplyChange_Increase:
msg = 'Increase supply by ' + str(change.delta) + ' whole units?'
msg = "Increase supply by " + str(change.delta) + " whole units?"
else:
raise ValueError('Invalid supply change type')
raise ValueError("Invalid supply change type")
await require_confirm_text(ctx, msg)
await require_confirm_final(ctx, common.fee)
def _creation_message(mosaic_creation):
return [ui.NORMAL, 'Create mosaic',
ui.BOLD, mosaic_creation.definition.mosaic,
ui.NORMAL, 'under namespace',
ui.BOLD, mosaic_creation.definition.namespace]
return [
ui.NORMAL,
"Create mosaic",
ui.BOLD,
mosaic_creation.definition.mosaic,
ui.NORMAL,
"under namespace",
ui.BOLD,
mosaic_creation.definition.namespace,
]
def _supply_message(supply_change):
return [ui.NORMAL, 'Modify supply for',
ui.BOLD, supply_change.mosaic,
ui.NORMAL, 'under namespace',
ui.BOLD, supply_change.namespace]
return [
ui.NORMAL,
"Modify supply for",
ui.BOLD,
supply_change.mosaic,
ui.NORMAL,
"under namespace",
ui.BOLD,
supply_change.namespace,
]
async def _require_confirm_properties(ctx, definition: NEMMosaicDefinition):
@ -81,64 +97,64 @@ def _get_mosaic_properties(definition: NEMMosaicDefinition):
# description
if definition.description:
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Description:')
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Description:")
t.normal(*split_words(trim(definition.description, 70), 22))
properties.append(t)
# transferable
if definition.transferable:
transferable = 'Yes'
transferable = "Yes"
else:
transferable = 'No'
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Transferable?')
transferable = "No"
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Transferable?")
t.normal(transferable)
properties.append(t)
# mutable_supply
if definition.mutable_supply:
imm = 'mutable'
imm = "mutable"
else:
imm = 'immutable'
imm = "immutable"
if definition.supply:
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Initial supply:')
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Initial supply:")
t.normal(str(definition.supply), imm)
else:
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Initial supply:')
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Initial supply:")
t.normal(imm)
properties.append(t)
# levy
if definition.levy:
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Levy recipient:')
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Levy recipient:")
t.mono(*split_address(definition.levy_address))
properties.append(t)
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Levy fee:')
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Levy fee:")
t.normal(str(definition.fee))
t.bold('Levy divisibility:')
t.bold("Levy divisibility:")
t.normal(str(definition.divisibility))
properties.append(t)
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Levy namespace:')
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Levy namespace:")
t.normal(definition.levy_namespace)
t.bold('Levy mosaic:')
t.bold("Levy mosaic:")
t.normal(definition.levy_mosaic)
properties.append(t)
if definition.levy == NEMMosaicLevy.MosaicLevy_Absolute:
levy_type = 'absolute'
levy_type = "absolute"
else:
levy_type = 'percentile'
t = Text('Confirm properties', ui.ICON_SEND)
t.bold('Levy type:')
levy_type = "percentile"
t = Text("Confirm properties", ui.ICON_SEND)
t.bold("Levy type:")
t.normal(levy_type)
properties.append(t)

@ -9,30 +9,51 @@ from ..helpers import (
from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64
def serialize_mosaic_creation(common: NEMTransactionCommon, creation: NEMMosaicCreation, public_key: bytes):
w = write_common(common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_CREATION)
def serialize_mosaic_creation(
common: NEMTransactionCommon, creation: NEMMosaicCreation, public_key: bytes
):
w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_CREATION
)
mosaics_w = bytearray()
write_bytes_with_length(mosaics_w, bytearray(public_key))
identifier_length = 4 + len(creation.definition.namespace) + 4 + len(creation.definition.mosaic)
identifier_length = (
4 + len(creation.definition.namespace) + 4 + len(creation.definition.mosaic)
)
write_uint32(mosaics_w, identifier_length)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.namespace))
write_bytes_with_length(mosaics_w, bytearray(creation.definition.mosaic))
write_bytes_with_length(mosaics_w, bytearray(creation.definition.description))
write_uint32(mosaics_w, 4) # number of properties
_write_property(mosaics_w, 'divisibility', creation.definition.divisibility)
_write_property(mosaics_w, 'initialSupply', creation.definition.supply)
_write_property(mosaics_w, 'supplyMutable', creation.definition.mutable_supply)
_write_property(mosaics_w, 'transferable', creation.definition.transferable)
_write_property(mosaics_w, "divisibility", creation.definition.divisibility)
_write_property(mosaics_w, "initialSupply", creation.definition.supply)
_write_property(mosaics_w, "supplyMutable", creation.definition.mutable_supply)
_write_property(mosaics_w, "transferable", creation.definition.transferable)
if creation.definition.levy:
levy_identifier_length = 4 + len(creation.definition.levy_namespace) + 4 + len(creation.definition.levy_mosaic)
write_uint32(mosaics_w, 4 + 4 + len(creation.definition.levy_address) + 4 + levy_identifier_length + 8)
levy_identifier_length = (
4
+ len(creation.definition.levy_namespace)
+ 4
+ len(creation.definition.levy_mosaic)
)
write_uint32(
mosaics_w,
4
+ 4
+ len(creation.definition.levy_address)
+ 4
+ levy_identifier_length
+ 8,
)
write_uint32(mosaics_w, creation.definition.levy)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_address))
write_uint32(mosaics_w, levy_identifier_length)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_namespace))
write_bytes_with_length(
mosaics_w, bytearray(creation.definition.levy_namespace)
)
write_bytes_with_length(mosaics_w, bytearray(creation.definition.levy_mosaic))
write_uint64(mosaics_w, creation.definition.fee)
else:
@ -47,8 +68,12 @@ def serialize_mosaic_creation(common: NEMTransactionCommon, creation: NEMMosaicC
return w
def serialize_mosaic_supply_change(common: NEMTransactionCommon, change: NEMMosaicSupplyChange, public_key: bytes):
w = write_common(common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_SUPPLY_CHANGE)
def serialize_mosaic_supply_change(
common: NEMTransactionCommon, change: NEMMosaicSupplyChange, public_key: bytes
):
w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_MOSAIC_SUPPLY_CHANGE
)
identifier_length = 4 + len(change.namespace) + 4 + len(change.mosaic)
write_uint32(w, identifier_length)
@ -62,19 +87,19 @@ def serialize_mosaic_supply_change(common: NEMTransactionCommon, change: NEMMosa
def _write_property(w: bytearray, name: str, value):
if value is None:
if name in ('divisibility', 'initialSupply'):
if name in ("divisibility", "initialSupply"):
value = 0
elif name in ('supplyMutable', 'transferable'):
elif name in ("supplyMutable", "transferable"):
value = False
if type(value) == bool:
if value:
value = 'true'
value = "true"
else:
value = 'false'
value = "false"
elif type(value) == int:
value = str(value)
elif type(value) != str:
raise ValueError('Incompatible value type')
raise ValueError("Incompatible value type")
write_uint32(w, 4 + len(name) + 4 + len(value))
write_bytes_with_length(w, bytearray(name))
write_bytes_with_length(w, bytearray(value))

@ -13,15 +13,19 @@ def initiate(public_key, common: NEMTransactionCommon, inner_tx: bytes) -> bytes
return serialize.serialize_multisig(common, public_key, inner_tx)
def cosign(public_key, common: NEMTransactionCommon, inner_tx: bytes, signer: bytes) -> bytes:
def cosign(
public_key, common: NEMTransactionCommon, inner_tx: bytes, signer: bytes
) -> bytes:
return serialize.serialize_multisig_signature(common, public_key, inner_tx, signer)
async def aggregate_modification(ctx,
public_key: bytes,
common: NEMTransactionCommon,
aggr: NEMAggregateModification,
multisig: bool):
async def aggregate_modification(
ctx,
public_key: bytes,
common: NEMTransactionCommon,
aggr: NEMAggregateModification,
multisig: bool,
):
await layout.ask_aggregate_modification(ctx, common, aggr, multisig)
w = serialize.serialize_aggregate_modification(common, aggr, public_key)

@ -21,36 +21,38 @@ from ..layout import (
async def ask_multisig(ctx, msg: NEMSignTx):
address = nem.compute_address(msg.multisig.signer, msg.transaction.network)
if msg.cosigning:
await _require_confirm_address(ctx, 'Cosign transaction for', address)
await _require_confirm_address(ctx, "Cosign transaction for", address)
else:
await _require_confirm_address(ctx, 'Initiate transaction for', address)
await require_confirm_fee(ctx, 'Confirm multisig fee', msg.transaction.fee)
await _require_confirm_address(ctx, "Initiate transaction for", address)
await require_confirm_fee(ctx, "Confirm multisig fee", msg.transaction.fee)
async def ask_aggregate_modification(ctx, common: NEMTransactionCommon, mod: NEMAggregateModification, multisig: bool):
async def ask_aggregate_modification(
ctx, common: NEMTransactionCommon, mod: NEMAggregateModification, multisig: bool
):
if not multisig:
await require_confirm_text(ctx, 'Convert account to multisig account?')
await require_confirm_text(ctx, "Convert account to multisig account?")
for m in mod.modifications:
if m.type == NEMModificationType.CosignatoryModification_Add:
action = 'Add'
action = "Add"
else:
action = 'Remove'
action = "Remove"
address = nem.compute_address(m.public_key, common.network)
await _require_confirm_address(ctx, action + ' cosignatory', address)
await _require_confirm_address(ctx, action + " cosignatory", address)
if mod.relative_change:
if multisig:
action = 'Modify the number of cosignatories by '
action = "Modify the number of cosignatories by "
else:
action = 'Set minimum cosignatories to '
await require_confirm_text(ctx, action + str(mod.relative_change) + '?')
action = "Set minimum cosignatories to "
await require_confirm_text(ctx, action + str(mod.relative_change) + "?")
await require_confirm_final(ctx, common.fee)
async def _require_confirm_address(ctx, action: str, address: str):
text = Text('Confirm address', ui.ICON_SEND, icon_color=ui.GREEN)
text = Text("Confirm address", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(action)
text.mono(*split_address(address))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)

@ -16,10 +16,16 @@ def serialize_multisig(common: NEMTransactionCommon, public_key: bytes, inner: b
return w
def serialize_multisig_signature(common: NEMTransactionCommon, public_key: bytes,
inner: bytes, address_public_key: bytes):
def serialize_multisig_signature(
common: NEMTransactionCommon,
public_key: bytes,
inner: bytes,
address_public_key: bytes,
):
address = nem.compute_address(address_public_key, common.network)
w = write_common(common, bytearray(public_key), NEM_TRANSACTION_TYPE_MULTISIG_SIGNATURE)
w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_MULTISIG_SIGNATURE
)
digest = hashlib.sha3_256(inner).digest(True)
write_uint32(w, 4 + len(digest))
@ -28,20 +34,26 @@ def serialize_multisig_signature(common: NEMTransactionCommon, public_key: bytes
return w
def serialize_aggregate_modification(common: NEMTransactionCommon, mod: NEMAggregateModification, public_key: bytes):
def serialize_aggregate_modification(
common: NEMTransactionCommon, mod: NEMAggregateModification, public_key: bytes
):
version = common.network << 24 | 1
if mod.relative_change:
version = common.network << 24 | 2
w = write_common(common,
bytearray(public_key),
NEM_TRANSACTION_TYPE_AGGREGATE_MODIFICATION,
version)
w = write_common(
common,
bytearray(public_key),
NEM_TRANSACTION_TYPE_AGGREGATE_MODIFICATION,
version,
)
write_uint32(w, len(mod.modifications))
return w
def serialize_cosignatory_modification(w: bytearray, type: int, cosignatory_pubkey: bytes):
def serialize_cosignatory_modification(
w: bytearray, type: int, cosignatory_pubkey: bytes
):
write_uint32(w, 4 + 4 + len(cosignatory_pubkey))
write_uint32(w, type)
write_bytes_with_length(w, bytearray(cosignatory_pubkey))

@ -4,6 +4,11 @@ from trezor.messages.NEMTransactionCommon import NEMTransactionCommon
from . import layout, serialize
async def namespace(ctx, public_key: bytes, common: NEMTransactionCommon, namespace: NEMProvisionNamespace) -> bytearray:
async def namespace(
ctx,
public_key: bytes,
common: NEMTransactionCommon,
namespace: NEMProvisionNamespace,
) -> bytearray:
await layout.ask_provision_namespace(ctx, common, namespace)
return serialize.serialize_provision_namespace(common, namespace, public_key)

@ -4,18 +4,25 @@ from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon
from ..layout import require_confirm_content, require_confirm_fee, require_confirm_final
async def ask_provision_namespace(ctx, common: NEMTransactionCommon, namespace: NEMProvisionNamespace):
async def ask_provision_namespace(
ctx, common: NEMTransactionCommon, namespace: NEMProvisionNamespace
):
if namespace.parent:
content = (ui.NORMAL, 'Create namespace',
ui.BOLD, namespace.namespace,
ui.NORMAL, 'under namespace',
ui.BOLD, namespace.parent)
await require_confirm_content(ctx, 'Confirm namespace', content)
content = (
ui.NORMAL,
"Create namespace",
ui.BOLD,
namespace.namespace,
ui.NORMAL,
"under namespace",
ui.BOLD,
namespace.parent,
)
await require_confirm_content(ctx, "Confirm namespace", content)
else:
content = (ui.NORMAL, 'Create namespace',
ui.BOLD, namespace.namespace)
await require_confirm_content(ctx, 'Confirm namespace', content)
content = (ui.NORMAL, "Create namespace", ui.BOLD, namespace.namespace)
await require_confirm_content(ctx, "Confirm namespace", content)
await require_confirm_fee(ctx, 'Confirm rental fee', namespace.fee)
await require_confirm_fee(ctx, "Confirm rental fee", namespace.fee)
await require_confirm_final(ctx, common.fee)

@ -5,10 +5,12 @@ from ..helpers import NEM_TRANSACTION_TYPE_PROVISION_NAMESPACE
from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64
def serialize_provision_namespace(common: NEMTransactionCommon, namespace: NEMProvisionNamespace, public_key: bytes) -> bytearray:
tx = write_common(common,
bytearray(public_key),
NEM_TRANSACTION_TYPE_PROVISION_NAMESPACE)
def serialize_provision_namespace(
common: NEMTransactionCommon, namespace: NEMProvisionNamespace, public_key: bytes
) -> bytearray:
tx = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_PROVISION_NAMESPACE
)
write_bytes_with_length(tx, bytearray(namespace.sink))
write_uint64(tx, namespace.fee)

@ -30,16 +30,26 @@ async def sign_tx(ctx, msg: NEMSignTx):
elif msg.supply_change:
tx = await mosaic.supply_change(ctx, public_key, common, msg.supply_change)
elif msg.aggregate_modification:
tx = await multisig.aggregate_modification(ctx, public_key, common, msg.aggregate_modification, msg.multisig is not None)
tx = await multisig.aggregate_modification(
ctx,
public_key,
common,
msg.aggregate_modification,
msg.multisig is not None,
)
elif msg.importance_transfer:
tx = await transfer.importance_transfer(ctx, public_key, common, msg.importance_transfer)
tx = await transfer.importance_transfer(
ctx, public_key, common, msg.importance_transfer
)
else:
raise ValueError('No transaction provided')
raise ValueError("No transaction provided")
if msg.multisig:
# wrap transaction in multisig wrapper
if msg.cosigning:
tx = multisig.cosign(_get_public_key(node), msg.transaction, tx, msg.multisig.signer)
tx = multisig.cosign(
_get_public_key(node), msg.transaction, tx, msg.multisig.signer
)
else:
tx = multisig.initiate(_get_public_key(node), msg.transaction, tx)

@ -5,7 +5,9 @@ from trezor.messages.NEMTransfer import NEMTransfer
from . import layout, serialize
async def transfer(ctx, public_key: bytes, common: NEMTransactionCommon, transfer: NEMTransfer, node):
async def transfer(
ctx, public_key: bytes, common: NEMTransactionCommon, transfer: NEMTransfer, node
):
transfer.mosaics = serialize.canonicalize_mosaics(transfer.mosaics)
payload, encrypted = serialize.get_transfer_payload(transfer, node)
@ -17,6 +19,8 @@ async def transfer(ctx, public_key: bytes, common: NEMTransactionCommon, transfe
return w
async def importance_transfer(ctx, public_key: bytes, common: NEMTransactionCommon, imp: NEMImportanceTransfer):
async def importance_transfer(
ctx, public_key: bytes, common: NEMTransactionCommon, imp: NEMImportanceTransfer
):
await layout.ask_importance_transfer(ctx, common, imp)
return serialize.serialize_importance_transfer(common, imp, public_key)

@ -22,7 +22,13 @@ from ..mosaic.helpers import get_mosaic_definition, is_nem_xem_mosaic
from apps.common.confirm import require_confirm
async def ask_transfer(ctx, common: NEMTransactionCommon, transfer: NEMTransfer, payload: bytes, encrypted: bool):
async def ask_transfer(
ctx,
common: NEMTransactionCommon,
transfer: NEMTransfer,
payload: bytes,
encrypted: bool,
):
if payload:
await _require_confirm_payload(ctx, transfer.payload, encrypted)
for mosaic in transfer.mosaics:
@ -31,7 +37,9 @@ async def ask_transfer(ctx, common: NEMTransactionCommon, transfer: NEMTransfer,
await require_confirm_final(ctx, common.fee)
async def ask_transfer_mosaic(ctx, common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic):
async def ask_transfer_mosaic(
ctx, common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic
):
if is_nem_xem_mosaic(mosaic.namespace, mosaic.mosaic):
return
@ -39,31 +47,38 @@ async def ask_transfer_mosaic(ctx, common: NEMTransactionCommon, transfer: NEMTr
mosaic_quantity = mosaic.quantity * transfer.amount / NEM_MOSAIC_AMOUNT_DIVISOR
if definition:
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal('Confirm transfer of')
msg.bold(format_amount(mosaic_quantity, definition['divisibility']) + definition['ticker'])
msg.normal('of')
msg.bold(definition['name'])
msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal("Confirm transfer of")
msg.bold(
format_amount(mosaic_quantity, definition["divisibility"])
+ definition["ticker"]
)
msg.normal("of")
msg.bold(definition["name"])
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
if 'levy' in definition and 'fee' in definition:
if "levy" in definition and "fee" in definition:
levy_msg = _get_levy_msg(definition, mosaic_quantity, common.network)
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal('Confirm mosaic', 'levy fee of')
msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal("Confirm mosaic", "levy fee of")
msg.bold(levy_msg)
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
else:
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.RED)
msg.bold('Unknown mosaic!')
msg.normal(*split_words('Divisibility and levy cannot be shown for unknown mosaics', 22))
msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.RED)
msg.bold("Unknown mosaic!")
msg.normal(
*split_words(
"Divisibility and levy cannot be shown for unknown mosaics", 22
)
)
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
msg = Text('Confirm mosaic', ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal('Confirm transfer of')
msg.bold('%s raw units' % mosaic_quantity)
msg.normal('of')
msg.bold('%s.%s' % (mosaic.namespace, mosaic.mosaic))
msg = Text("Confirm mosaic", ui.ICON_SEND, icon_color=ui.GREEN)
msg.normal("Confirm transfer of")
msg.bold("%s raw units" % mosaic_quantity)
msg.normal("of")
msg.bold("%s.%s" % (mosaic.namespace, mosaic.mosaic))
await require_confirm(ctx, msg, ButtonRequestType.ConfirmOutput)
@ -81,47 +96,50 @@ def _get_xem_amount(transfer: NEMTransfer):
def _get_levy_msg(mosaic_definition, quantity: int, network: int) -> str:
levy_definition = get_mosaic_definition(
mosaic_definition['levy_namespace'],
mosaic_definition['levy_mosaic'],
network)
if mosaic_definition['levy'] == NEMMosaicLevy.MosaicLevy_Absolute:
levy_fee = mosaic_definition['fee']
mosaic_definition["levy_namespace"], mosaic_definition["levy_mosaic"], network
)
if mosaic_definition["levy"] == NEMMosaicLevy.MosaicLevy_Absolute:
levy_fee = mosaic_definition["fee"]
else:
levy_fee = quantity * mosaic_definition['fee'] / NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE
return format_amount(
levy_fee,
levy_definition['divisibility']
) + levy_definition['ticker']
async def ask_importance_transfer(ctx, common: NEMTransactionCommon, imp: NEMImportanceTransfer):
levy_fee = (
quantity * mosaic_definition["fee"] / NEM_LEVY_PERCENTILE_DIVISOR_ABSOLUTE
)
return (
format_amount(levy_fee, levy_definition["divisibility"])
+ levy_definition["ticker"]
)
async def ask_importance_transfer(
ctx, common: NEMTransactionCommon, imp: NEMImportanceTransfer
):
if imp.mode == NEMImportanceTransferMode.ImportanceTransfer_Activate:
m = 'Activate'
m = "Activate"
else:
m = 'Deactivate'
await require_confirm_text(ctx, m + ' remote harvesting?')
m = "Deactivate"
await require_confirm_text(ctx, m + " remote harvesting?")
await require_confirm_final(ctx, common.fee)
async def _require_confirm_transfer(ctx, recipient, value):
text = Text('Confirm transfer', ui.ICON_SEND, icon_color=ui.GREEN)
text.bold('Send %s XEM' % format_amount(value, NEM_MAX_DIVISIBILITY))
text.normal('to')
text = Text("Confirm transfer", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold("Send %s XEM" % format_amount(value, NEM_MAX_DIVISIBILITY))
text.normal("to")
text.mono(*split_address(recipient))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def _require_confirm_payload(ctx, payload: bytes, encrypt=False):
payload = str(payload, 'utf-8')
payload = str(payload, "utf-8")
if len(payload) > 48:
payload = payload[:48] + '..'
payload = payload[:48] + ".."
if encrypt:
text = Text('Confirm payload', ui.ICON_SEND, icon_color=ui.GREEN)
text.bold('Encrypted:')
text = Text("Confirm payload", ui.ICON_SEND, icon_color=ui.GREEN)
text.bold("Encrypted:")
text.normal(*split_words(payload, 22))
else:
text = Text('Confirm payload', ui.ICON_SEND, icon_color=ui.RED)
text.bold('Unencrypted:')
text = Text("Confirm payload", ui.ICON_SEND, icon_color=ui.RED)
text.bold("Unencrypted:")
text.normal(*split_words(payload, 22))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)

@ -13,14 +13,19 @@ from ..helpers import (
from ..writers import write_bytes_with_length, write_common, write_uint32, write_uint64
def serialize_transfer(common: NEMTransactionCommon,
transfer: NEMTransfer,
public_key: bytes,
payload: bytes = None,
encrypted: bool = False) -> bytearray:
tx = write_common(common, bytearray(public_key),
NEM_TRANSACTION_TYPE_TRANSFER,
_get_version(common.network, transfer.mosaics))
def serialize_transfer(
common: NEMTransactionCommon,
transfer: NEMTransfer,
public_key: bytes,
payload: bytes = None,
encrypted: bool = False,
) -> bytearray:
tx = write_common(
common,
bytearray(public_key),
NEM_TRANSACTION_TYPE_TRANSFER,
_get_version(common.network, transfer.mosaics),
)
write_bytes_with_length(tx, bytearray(transfer.recipient))
write_uint64(tx, transfer.amount)
@ -52,11 +57,12 @@ def serialize_mosaic(w: bytearray, namespace: str, mosaic: str, quantity: int):
write_uint64(w, quantity)
def serialize_importance_transfer(common: NEMTransactionCommon,
imp: NEMImportanceTransfer,
public_key: bytes) -> bytearray:
w = write_common(common, bytearray(public_key),
NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER)
def serialize_importance_transfer(
common: NEMTransactionCommon, imp: NEMImportanceTransfer, public_key: bytes
) -> bytearray:
w = write_common(
common, bytearray(public_key), NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER
)
write_uint32(w, imp.mode)
write_bytes_with_length(w, bytearray(imp.public_key))
@ -68,7 +74,7 @@ def get_transfer_payload(transfer: NEMTransfer, node) -> [bytes, bool]:
encrypted = False
if transfer.public_key is not None:
if payload is None:
raise ValueError('Public key provided but no payload to encrypt')
raise ValueError("Public key provided but no payload to encrypt")
payload = _encrypt(node, transfer.public_key, transfer.payload)
encrypted = True

@ -26,7 +26,7 @@ from .helpers import (
def validate(msg: NEMSignTx):
if msg.transaction is None:
raise ProcessError('No common provided')
raise ProcessError("No common provided")
_validate_single_tx(msg)
_validate_common(msg.transaction)
@ -35,7 +35,7 @@ def validate(msg: NEMSignTx):
_validate_common(msg.multisig, True)
_validate_multisig(msg.multisig, msg.transaction.network)
if not msg.multisig and msg.cosigning:
raise ProcessError('No multisig transaction to cosign')
raise ProcessError("No multisig transaction to cosign")
if msg.transfer:
_validate_transfer(msg.transfer, msg.transaction.network)
@ -46,7 +46,9 @@ def validate(msg: NEMSignTx):
if msg.supply_change:
_validate_supply_change(msg.supply_change)
if msg.aggregate_modification:
_validate_aggregate_modification(msg.aggregate_modification, msg.multisig is None)
_validate_aggregate_modification(
msg.aggregate_modification, msg.multisig is None
)
if msg.importance_transfer:
_validate_importance_transfer(msg.importance_transfer)
@ -55,23 +57,24 @@ def validate_network(network: int) -> int:
if network is None:
return NEM_NETWORK_MAINNET
if network not in (NEM_NETWORK_MAINNET, NEM_NETWORK_TESTNET, NEM_NETWORK_MIJIN):
raise ProcessError('Invalid NEM network')
raise ProcessError("Invalid NEM network")
return network
def _validate_single_tx(msg: NEMSignTx):
# ensure exactly one transaction is provided
tx_count = \
bool(msg.transfer) + \
bool(msg.provision_namespace) + \
bool(msg.mosaic_creation) + \
bool(msg.supply_change) + \
bool(msg.aggregate_modification) + \
bool(msg.importance_transfer)
tx_count = (
bool(msg.transfer)
+ bool(msg.provision_namespace)
+ bool(msg.mosaic_creation)
+ bool(msg.supply_change)
+ bool(msg.aggregate_modification)
+ bool(msg.importance_transfer)
)
if tx_count == 0:
raise ProcessError('No transaction provided')
raise ProcessError("No transaction provided")
if tx_count > 1:
raise ProcessError('More than one transaction provided')
raise ProcessError("More than one transaction provided")
def _validate_common(common: NEMTransactionCommon, inner: bool = False):
@ -79,174 +82,196 @@ def _validate_common(common: NEMTransactionCommon, inner: bool = False):
err = None
if common.timestamp is None:
err = 'timestamp'
err = "timestamp"
if common.fee is None:
err = 'fee'
err = "fee"
if common.deadline is None:
err = 'deadline'
err = "deadline"
if not inner and common.signer:
raise ProcessError('Signer not allowed in outer transaction')
raise ProcessError("Signer not allowed in outer transaction")
if inner and common.signer is None:
err = 'signer'
err = "signer"
if err:
if inner:
raise ProcessError('No %s provided in inner transaction' % err)
raise ProcessError("No %s provided in inner transaction" % err)
else:
raise ProcessError('No %s provided' % err)
raise ProcessError("No %s provided" % err)
if common.signer is not None:
_validate_public_key(common.signer, 'Invalid signer public key in inner transaction')
_validate_public_key(
common.signer, "Invalid signer public key in inner transaction"
)
def _validate_public_key(public_key: bytes, err_msg: str):
if not public_key:
raise ProcessError('%s (none provided)' % err_msg)
raise ProcessError("%s (none provided)" % err_msg)
if len(public_key) != NEM_PUBLIC_KEY_SIZE:
raise ProcessError('%s (invalid length)' % err_msg)
raise ProcessError("%s (invalid length)" % err_msg)
def _validate_importance_transfer(importance_transfer: NEMImportanceTransfer):
if importance_transfer.mode is None:
raise ProcessError('No mode provided')
_validate_public_key(importance_transfer.public_key, 'Invalid remote account public key provided')
raise ProcessError("No mode provided")
_validate_public_key(
importance_transfer.public_key, "Invalid remote account public key provided"
)
def _validate_multisig(multisig: NEMTransactionCommon, network: int):
if multisig.network != network:
raise ProcessError('Inner transaction network is different')
_validate_public_key(multisig.signer, 'Invalid multisig signer public key provided')
raise ProcessError("Inner transaction network is different")
_validate_public_key(multisig.signer, "Invalid multisig signer public key provided")
def _validate_aggregate_modification(
aggregate_modification: NEMAggregateModification,
creation: bool = False):
aggregate_modification: NEMAggregateModification, creation: bool = False
):
if creation and not aggregate_modification.modifications:
raise ProcessError('No modifications provided')
raise ProcessError("No modifications provided")
for m in aggregate_modification.modifications:
if not m.type:
raise ProcessError('No modification type provided')
raise ProcessError("No modification type provided")
if m.type not in (
NEMModificationType.CosignatoryModification_Add,
NEMModificationType.CosignatoryModification_Delete
NEMModificationType.CosignatoryModification_Delete,
):
raise ProcessError('Unknown aggregate modification')
raise ProcessError("Unknown aggregate modification")
if creation and m.type == NEMModificationType.CosignatoryModification_Delete:
raise ProcessError('Cannot remove cosignatory when converting account')
_validate_public_key(m.public_key, 'Invalid cosignatory public key provided')
raise ProcessError("Cannot remove cosignatory when converting account")
_validate_public_key(m.public_key, "Invalid cosignatory public key provided")
def _validate_supply_change(supply_change: NEMMosaicSupplyChange):
if supply_change.namespace is None:
raise ProcessError('No namespace provided')
raise ProcessError("No namespace provided")
if supply_change.mosaic is None:
raise ProcessError('No mosaic provided')
raise ProcessError("No mosaic provided")
if supply_change.type is None:
raise ProcessError('No type provided')
elif supply_change.type not in [NEMSupplyChangeType.SupplyChange_Decrease, NEMSupplyChangeType.SupplyChange_Increase]:
raise ProcessError('Invalid supply change type')
raise ProcessError("No type provided")
elif supply_change.type not in [
NEMSupplyChangeType.SupplyChange_Decrease,
NEMSupplyChangeType.SupplyChange_Increase,
]:
raise ProcessError("Invalid supply change type")
if supply_change.delta is None:
raise ProcessError('No delta provided')
raise ProcessError("No delta provided")
def _validate_mosaic_creation(mosaic_creation: NEMMosaicCreation, network: int):
if mosaic_creation.definition is None:
raise ProcessError('No mosaic definition provided')
raise ProcessError("No mosaic definition provided")
if mosaic_creation.sink is None:
raise ProcessError('No creation sink provided')
raise ProcessError("No creation sink provided")
if mosaic_creation.fee is None:
raise ProcessError('No creation sink fee provided')
raise ProcessError("No creation sink fee provided")
if not nem.validate_address(mosaic_creation.sink, network):
raise ProcessError('Invalid creation sink address')
raise ProcessError("Invalid creation sink address")
if mosaic_creation.definition.name is not None:
raise ProcessError('Name not allowed in mosaic creation transactions')
raise ProcessError("Name not allowed in mosaic creation transactions")
if mosaic_creation.definition.ticker is not None:
raise ProcessError('Ticker not allowed in mosaic creation transactions')
raise ProcessError("Ticker not allowed in mosaic creation transactions")
if mosaic_creation.definition.networks:
raise ProcessError('Networks not allowed in mosaic creation transactions')
raise ProcessError("Networks not allowed in mosaic creation transactions")
if mosaic_creation.definition.namespace is None:
raise ProcessError('No mosaic namespace provided')
raise ProcessError("No mosaic namespace provided")
if mosaic_creation.definition.mosaic is None:
raise ProcessError('No mosaic name provided')
if mosaic_creation.definition.supply is not None and mosaic_creation.definition.divisibility is None:
raise ProcessError('Definition divisibility needs to be provided when supply is')
if mosaic_creation.definition.supply is None and mosaic_creation.definition.divisibility is not None:
raise ProcessError('Definition supply needs to be provided when divisibility is')
raise ProcessError("No mosaic name provided")
if (
mosaic_creation.definition.supply is not None
and mosaic_creation.definition.divisibility is None
):
raise ProcessError(
"Definition divisibility needs to be provided when supply is"
)
if (
mosaic_creation.definition.supply is None
and mosaic_creation.definition.divisibility is not None
):
raise ProcessError(
"Definition supply needs to be provided when divisibility is"
)
if mosaic_creation.definition.levy is not None:
if mosaic_creation.definition.fee is None:
raise ProcessError('No levy fee provided')
raise ProcessError("No levy fee provided")
if mosaic_creation.definition.levy_address is None:
raise ProcessError('No levy address provided')
raise ProcessError("No levy address provided")
if mosaic_creation.definition.levy_namespace is None:
raise ProcessError('No levy namespace provided')
raise ProcessError("No levy namespace provided")
if mosaic_creation.definition.levy_mosaic is None:
raise ProcessError('No levy mosaic name provided')
raise ProcessError("No levy mosaic name provided")
if mosaic_creation.definition.divisibility is None:
raise ProcessError('No divisibility provided')
raise ProcessError("No divisibility provided")
if mosaic_creation.definition.supply is None:
raise ProcessError('No supply provided')
raise ProcessError("No supply provided")
if mosaic_creation.definition.mutable_supply is None:
raise ProcessError('No supply mutability provided')
raise ProcessError("No supply mutability provided")
if mosaic_creation.definition.transferable is None:
raise ProcessError('No mosaic transferability provided')
raise ProcessError("No mosaic transferability provided")
if mosaic_creation.definition.description is None:
raise ProcessError('No description provided')
raise ProcessError("No description provided")
if mosaic_creation.definition.divisibility > NEM_MAX_DIVISIBILITY:
raise ProcessError('Invalid divisibility provided')
raise ProcessError("Invalid divisibility provided")
if mosaic_creation.definition.supply > NEM_MAX_SUPPLY:
raise ProcessError('Invalid supply provided')
raise ProcessError("Invalid supply provided")
if not nem.validate_address(mosaic_creation.definition.levy_address, network):
raise ProcessError('Invalid levy address')
raise ProcessError("Invalid levy address")
def _validate_provision_namespace(provision_namespace: NEMProvisionNamespace, network: int):
def _validate_provision_namespace(
provision_namespace: NEMProvisionNamespace, network: int
):
if provision_namespace.namespace is None:
raise ProcessError('No namespace provided')
raise ProcessError("No namespace provided")
if provision_namespace.sink is None:
raise ProcessError('No rental sink provided')
raise ProcessError("No rental sink provided")
if provision_namespace.fee is None:
raise ProcessError('No rental sink fee provided')
raise ProcessError("No rental sink fee provided")
if not nem.validate_address(provision_namespace.sink, network):
raise ProcessError('Invalid rental sink address')
raise ProcessError("Invalid rental sink address")
def _validate_transfer(transfer: NEMTransfer, network: int):
if transfer.recipient is None:
raise ProcessError('No recipient provided')
raise ProcessError("No recipient provided")
if transfer.amount is None:
raise ProcessError('No amount provided')
raise ProcessError("No amount provided")
if transfer.public_key is not None:
_validate_public_key(transfer.public_key, 'Invalid recipient public key')
_validate_public_key(transfer.public_key, "Invalid recipient public key")
if transfer.payload is None:
raise ProcessError('Public key provided but no payload to encrypt')
raise ProcessError("Public key provided but no payload to encrypt")
if transfer.payload:
if len(transfer.payload) > NEM_MAX_PLAIN_PAYLOAD_SIZE:
raise ProcessError('Payload too large')
if transfer.public_key and len(transfer.payload) > NEM_MAX_ENCRYPTED_PAYLOAD_SIZE:
raise ProcessError('Payload too large')
raise ProcessError("Payload too large")
if (
transfer.public_key
and len(transfer.payload) > NEM_MAX_ENCRYPTED_PAYLOAD_SIZE
):
raise ProcessError("Payload too large")
if not nem.validate_address(transfer.recipient, network):
raise ProcessError('Invalid recipient address')
raise ProcessError("Invalid recipient address")
for m in transfer.mosaics:
if m.namespace is None:
raise ProcessError('No mosaic namespace provided')
raise ProcessError("No mosaic namespace provided")
if m.mosaic is None:
raise ProcessError('No mosaic name provided')
raise ProcessError("No mosaic name provided")
if m.quantity is None:
raise ProcessError('No mosaic quantity provided')
raise ProcessError("No mosaic quantity provided")

@ -28,10 +28,12 @@ def write_bytes_with_length(w, buf: bytearray):
write_bytes(w, buf)
def write_common(common: NEMTransactionCommon,
public_key: bytearray,
transaction_type: int,
version: int = None) -> bytearray:
def write_common(
common: NEMTransactionCommon,
public_key: bytearray,
transaction_type: int,
version: int = None,
) -> bytearray:
ret = bytearray()
write_uint32(ret, transaction_type)

@ -14,46 +14,55 @@ from trezor.wire import protobuf_workflow, register
def dispatch_GetPublicKey(*args, **kwargs):
from .get_public_key import get_public_key
return get_public_key(*args, **kwargs)
def dispatch_GetAddress(*args, **kwargs):
from .get_address import get_address
return get_address(*args, **kwargs)
def dispatch_GetEntropy(*args, **kwargs):
from .get_entropy import get_entropy
return get_entropy(*args, **kwargs)
def dispatch_SignTx(*args, **kwargs):
from .sign_tx import sign_tx
return sign_tx(*args, **kwargs)
def dispatch_SignMessage(*args, **kwargs):
from .sign_message import sign_message
return sign_message(*args, **kwargs)
def dispatch_VerifyMessage(*args, **kwargs):
from .verify_message import verify_message
return verify_message(*args, **kwargs)
def dispatch_SignIdentity(*args, **kwargs):
from .sign_identity import sign_identity
return sign_identity(*args, **kwargs)
def dispatch_GetECDHSessionKey(*args, **kwargs):
from .ecdh import get_ecdh_session_key
return get_ecdh_session_key(*args, **kwargs)
def dispatch_CipherKeyValue(*args, **kwargs):
from .cipher_key_value import cipher_key_value
return cipher_key_value(*args, **kwargs)

@ -11,15 +11,15 @@ from apps.common.confirm import require_confirm
async def cipher_key_value(ctx, msg):
if len(msg.value) % 16 > 0:
raise wire.DataError('Value length must be a multiple of 16')
raise wire.DataError("Value length must be a multiple of 16")
encrypt = msg.encrypt
decrypt = not msg.encrypt
if (encrypt and msg.ask_on_encrypt) or (decrypt and msg.ask_on_decrypt):
if encrypt:
title = 'Encrypt value'
title = "Encrypt value"
else:
title = 'Decrypt value'
title = "Decrypt value"
text = Text(title)
text.normal(msg.key)
await require_confirm(ctx, text)
@ -31,8 +31,8 @@ async def cipher_key_value(ctx, msg):
def compute_cipher_key_value(msg, seckey: bytes) -> bytes:
data = msg.key
data += 'E1' if msg.ask_on_encrypt else 'E0'
data += 'D1' if msg.ask_on_decrypt else 'D0'
data += "E1" if msg.ask_on_encrypt else "E0"
data += "D1" if msg.ask_on_decrypt else "D0"
data = hmac.new(seckey, data, sha512).digest()
key = data[:32]
if msg.iv and len(msg.iv) == 16:

@ -15,7 +15,7 @@ from apps.wallet.sign_identity import (
async def get_ecdh_session_key(ctx, msg):
if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = 'secp256k1'
msg.ecdsa_curve_name = "secp256k1"
identity = serialize_identity(msg.identity)
@ -24,42 +24,47 @@ async def get_ecdh_session_key(ctx, msg):
address_n = get_ecdh_path(identity, msg.identity.index or 0)
node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name)
session_key = ecdh(seckey=node.private_key(),
peer_public_key=msg.peer_public_key,
curve=msg.ecdsa_curve_name)
session_key = ecdh(
seckey=node.private_key(),
peer_public_key=msg.peer_public_key,
curve=msg.ecdsa_curve_name,
)
return ECDHSessionKey(session_key=session_key)
async def require_confirm_ecdh_session_key(ctx, identity):
lines = chunks(serialize_identity_without_proto(identity), 18)
proto = identity.proto.upper() if identity.proto else 'identity'
text = Text('Decrypt %s' % proto)
proto = identity.proto.upper() if identity.proto else "identity"
text = Text("Decrypt %s" % proto)
text.mono(*lines)
await require_confirm(ctx, text)
def get_ecdh_path(identity: str, index: int):
identity_hash = sha256(pack('<I', index) + identity).digest()
identity_hash = sha256(pack("<I", index) + identity).digest()
address_n = (17, ) + unpack('<IIII', identity_hash[:16])
address_n = (17,) + unpack("<IIII", identity_hash[:16])
address_n = [HARDENED | x for x in address_n]
return address_n
def ecdh(seckey: bytes, peer_public_key: bytes, curve: str) -> bytes:
if curve == 'secp256k1':
if curve == "secp256k1":
from trezor.crypto.curve import secp256k1
session_key = secp256k1.multiply(seckey, peer_public_key)
elif curve == 'nist256p1':
elif curve == "nist256p1":
from trezor.crypto.curve import nist256p1
session_key = nist256p1.multiply(seckey, peer_public_key)
elif curve == 'curve25519':
elif curve == "curve25519":
from trezor.crypto.curve import curve25519
if peer_public_key[0] != 0x40:
raise ValueError('Curve25519 public key should start with 0x40')
session_key = b'\x04' + curve25519.multiply(seckey, peer_public_key[1:])
raise ValueError("Curve25519 public key should start with 0x40")
session_key = b"\x04" + curve25519.multiply(seckey, peer_public_key[1:])
else:
raise ValueError('Unsupported curve for ECDH: ' + curve)
raise ValueError("Unsupported curve for ECDH: " + curve)
return session_key

@ -7,7 +7,7 @@ from apps.wallet.sign_tx import addresses
async def get_address(ctx, msg):
coin_name = msg.coin_name or 'Bitcoin'
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
node = await seed.derive_node(ctx, msg.address_n, curve_name=coin.curve_name)
@ -18,7 +18,12 @@ async def get_address(ctx, msg):
while True:
if await show_address(ctx, address_short):
break
if await show_qr(ctx, address.upper() if msg.script_type == InputScriptType.SPENDWITNESS else address):
if await show_qr(
ctx,
address.upper()
if msg.script_type == InputScriptType.SPENDWITNESS
else address,
):
break
return Address(address=address)

@ -8,9 +8,9 @@ from apps.common.confirm import require_confirm
async def get_entropy(ctx, msg):
text = Text('Confirm entropy')
text.bold('Do you really want', 'to send entropy?')
text.normal('Continue only if you', 'know what you are doing!')
text = Text("Confirm entropy")
text.bold("Do you really want", "to send entropy?")
text.normal("Continue only if you", "know what you are doing!")
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
size = min(msg.size, 1024)

@ -12,7 +12,7 @@ from apps.common.confirm import require_confirm
async def get_public_key(ctx, msg):
coin_name = msg.coin_name or 'Bitcoin'
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
curve_name = msg.ecdsa_curve_name
@ -23,13 +23,14 @@ async def get_public_key(ctx, msg):
node_xpub = node.serialize_public(coin.xpub_magic)
pubkey = node.public_key()
if pubkey[0] == 1:
pubkey = b'\x00' + pubkey[1:]
pubkey = b"\x00" + pubkey[1:]
node_type = HDNodeType(
depth=node.depth(),
child_num=node.child_num(),
fingerprint=node.fingerprint(),
chain_code=node.chain_code(),
public_key=pubkey)
public_key=pubkey,
)
if msg.show_display:
await _show_pubkey(ctx, pubkey)
@ -39,9 +40,6 @@ async def get_public_key(ctx, msg):
async def _show_pubkey(ctx, pubkey: bytes):
lines = chunks(hexlify(pubkey).decode(), 18)
text = Text('Confirm public key', ui.ICON_RECEIVE, icon_color=ui.GREEN)
text = Text("Confirm public key", ui.ICON_RECEIVE, icon_color=ui.GREEN)
text.mono(*lines)
return await require_confirm(
ctx,
text,
code=ButtonRequestType.PublicKey)
return await require_confirm(ctx, text, code=ButtonRequestType.PublicKey)

@ -12,7 +12,7 @@ from apps.common.confirm import require_confirm
async def sign_identity(ctx, msg):
if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = 'secp256k1'
msg.ecdsa_curve_name = "secp256k1"
identity = serialize_identity(msg.identity)
@ -21,25 +21,40 @@ async def sign_identity(ctx, msg):
address_n = get_identity_path(identity, msg.identity.index or 0)
node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name)
coin = coins.by_name('Bitcoin')
if msg.ecdsa_curve_name == 'secp256k1':
coin = coins.by_name("Bitcoin")
if msg.ecdsa_curve_name == "secp256k1":
address = node.address(coin.address_type) # hardcoded bitcoin address type
else:
address = None
pubkey = node.public_key()
if pubkey[0] == 0x01:
pubkey = b'\x00' + pubkey[1:]
pubkey = b"\x00" + pubkey[1:]
seckey = node.private_key()
if msg.identity.proto == 'gpg':
if msg.identity.proto == "gpg":
signature = sign_challenge(
seckey, msg.challenge_hidden, msg.challenge_visual, 'gpg', msg.ecdsa_curve_name)
elif msg.identity.proto == 'ssh':
seckey,
msg.challenge_hidden,
msg.challenge_visual,
"gpg",
msg.ecdsa_curve_name,
)
elif msg.identity.proto == "ssh":
signature = sign_challenge(
seckey, msg.challenge_hidden, msg.challenge_visual, 'ssh', msg.ecdsa_curve_name)
seckey,
msg.challenge_hidden,
msg.challenge_visual,
"ssh",
msg.ecdsa_curve_name,
)
else:
signature = sign_challenge(
seckey, msg.challenge_hidden, msg.challenge_visual, coin, msg.ecdsa_curve_name)
seckey,
msg.challenge_hidden,
msg.challenge_visual,
coin,
msg.ecdsa_curve_name,
)
return SignedIdentity(address=address, public_key=pubkey, signature=signature)
@ -52,22 +67,22 @@ async def require_confirm_sign_identity(ctx, identity, challenge_visual):
lines.append(ui.MONO)
lines.extend(chunks(serialize_identity_without_proto(identity), 18))
proto = identity.proto.upper() if identity.proto else 'identity'
text = Text('Sign %s' % proto)
proto = identity.proto.upper() if identity.proto else "identity"
text = Text("Sign %s" % proto)
text.normal(*lines)
await require_confirm(ctx, text)
def serialize_identity(identity):
s = ''
s = ""
if identity.proto:
s += identity.proto + '://'
s += identity.proto + "://"
if identity.user:
s += identity.user + '@'
s += identity.user + "@"
if identity.host:
s += identity.host
if identity.port:
s += ':' + identity.port
s += ":" + identity.port
if identity.path:
s += identity.path
return s
@ -82,52 +97,53 @@ def serialize_identity_without_proto(identity):
def get_identity_path(identity: str, index: int):
identity_hash = sha256(pack('<I', index) + identity).digest()
identity_hash = sha256(pack("<I", index) + identity).digest()
address_n = (13, ) + unpack('<IIII', identity_hash[:16])
address_n = (13,) + unpack("<IIII", identity_hash[:16])
address_n = [HARDENED | x for x in address_n]
return address_n
def sign_challenge(seckey: bytes,
challenge_hidden: bytes,
challenge_visual: str,
sigtype,
curve: str) -> bytes:
def sign_challenge(
seckey: bytes, challenge_hidden: bytes, challenge_visual: str, sigtype, curve: str
) -> bytes:
from trezor.crypto.hashlib import sha256
if curve == 'secp256k1':
if curve == "secp256k1":
from trezor.crypto.curve import secp256k1
elif curve == 'nist256p1':
elif curve == "nist256p1":
from trezor.crypto.curve import nist256p1
elif curve == 'ed25519':
elif curve == "ed25519":
from trezor.crypto.curve import ed25519
from apps.common.signverify import message_digest
if sigtype == 'gpg':
if sigtype == "gpg":
data = challenge_hidden
elif sigtype == 'ssh':
if curve != 'ed25519':
elif sigtype == "ssh":
if curve != "ed25519":
data = sha256(challenge_hidden).digest()
else:
data = challenge_hidden
else:
# sigtype is coin
challenge = sha256(challenge_hidden).digest() + sha256(challenge_visual).digest()
challenge = (
sha256(challenge_hidden).digest() + sha256(challenge_visual).digest()
)
data = message_digest(sigtype, challenge)
if curve == 'secp256k1':
if curve == "secp256k1":
signature = secp256k1.sign(seckey, data)
elif curve == 'nist256p1':
elif curve == "nist256p1":
signature = nist256p1.sign(seckey, data)
elif curve == 'ed25519':
elif curve == "ed25519":
signature = ed25519.sign(seckey, data)
else:
raise ValueError('Unknown curve')
raise ValueError("Unknown curve")
if curve == 'ed25519':
signature = b'\x00' + signature
elif sigtype == 'gpg' or sigtype == 'ssh':
signature = b'\x00' + signature[1:]
if curve == "ed25519":
signature = b"\x00" + signature
elif sigtype == "gpg" or sigtype == "ssh":
signature = b"\x00" + signature[1:]
return signature

@ -13,7 +13,7 @@ from apps.wallet.sign_tx.addresses import get_address
async def sign_message(ctx, msg):
message = msg.message
address_n = msg.address_n
coin_name = msg.coin_name or 'Bitcoin'
coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or 0
coin = coins.by_name(coin_name)
@ -33,13 +33,13 @@ async def sign_message(ctx, msg):
elif script_type == SPENDWITNESS:
signature = bytes([signature[0] + 8]) + signature[1:]
else:
raise wire.ProcessError('Unsupported script type')
raise wire.ProcessError("Unsupported script type")
return MessageSignature(address=address, signature=signature)
async def require_confirm_sign_message(ctx, message):
message = split_message(message)
text = Text('Sign message')
text = Text("Sign message")
text.normal(*message)
await require_confirm(ctx, text)

@ -22,18 +22,23 @@ class AddressError(Exception):
pass
def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=None) -> str:
if script_type == InputScriptType.SPENDADDRESS or script_type == InputScriptType.SPENDMULTISIG:
def get_address(
script_type: InputScriptType, coin: CoinInfo, node, multisig=None
) -> str:
if (
script_type == InputScriptType.SPENDADDRESS
or script_type == InputScriptType.SPENDMULTISIG
):
if multisig: # p2sh multisig
pubkey = node.public_key()
index = multisig_pubkey_index(multisig, pubkey)
if index is None:
raise AddressError(FailureType.ProcessError,
'Public key not found')
raise AddressError(FailureType.ProcessError, "Public key not found")
if coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError,
'Multisig not enabled on this coin')
raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin"
)
pubkeys = multisig_get_pubkeys(multisig)
address = address_multisig_p2sh(pubkeys, multisig.m, coin)
@ -41,8 +46,7 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
address = address_to_cashaddr(address, coin)
return address
if script_type == InputScriptType.SPENDMULTISIG:
raise AddressError(FailureType.ProcessError,
'Multisig details required')
raise AddressError(FailureType.ProcessError, "Multisig details required")
# p2pkh
address = node.address(coin.address_type)
@ -52,8 +56,9 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or native p2wsh
if not coin.segwit or not coin.bech32_prefix:
raise AddressError(FailureType.ProcessError,
'Segwit not enabled on this coin')
raise AddressError(
FailureType.ProcessError, "Segwit not enabled on this coin"
)
# native p2wsh multisig
if multisig is not None:
pubkeys = multisig_get_pubkeys(multisig)
@ -62,10 +67,13 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
# native p2wpkh
return address_p2wpkh(node.public_key(), coin.bech32_prefix)
elif script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh or p2wsh nested in p2sh
elif (
script_type == InputScriptType.SPENDP2SHWITNESS
): # p2wpkh or p2wsh nested in p2sh
if not coin.segwit or coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError,
'Segwit not enabled on this coin')
raise AddressError(
FailureType.ProcessError, "Segwit not enabled on this coin"
)
# p2wsh multisig nested in p2sh
if multisig is not None:
pubkeys = multisig_get_pubkeys(multisig)
@ -75,14 +83,14 @@ def get_address(script_type: InputScriptType, coin: CoinInfo, node, multisig=Non
return address_p2wpkh_in_p2sh(node.public_key(), coin)
else:
raise AddressError(FailureType.ProcessError,
'Invalid script type')
raise AddressError(FailureType.ProcessError, "Invalid script type")
def address_multisig_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
if coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError,
'Multisig not enabled on this coin')
raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin"
)
redeem_script = output_script_multisig(pubkeys, m)
redeem_script_hash = sha256_ripemd160_digest(redeem_script)
return address_p2sh(redeem_script_hash, coin)
@ -90,8 +98,9 @@ def address_multisig_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
def address_multisig_p2wsh_in_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
if coin.address_type_p2sh is None:
raise AddressError(FailureType.ProcessError,
'Multisig not enabled on this coin')
raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin"
)
witness_script = output_script_multisig(pubkeys, m)
witness_script_hash = sha256(witness_script).digest()
return address_p2wsh_in_p2sh(witness_script_hash, coin)
@ -99,8 +108,9 @@ def address_multisig_p2wsh_in_p2sh(pubkeys: bytes, m: int, coin: CoinInfo):
def address_multisig_p2wsh(pubkeys: bytes, m: int, hrp: str):
if not hrp:
raise AddressError(FailureType.ProcessError,
'Multisig not enabled on this coin')
raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin"
)
witness_script = output_script_multisig(pubkeys, m)
witness_script_hash = sha256(witness_script).digest()
return address_p2wsh(witness_script_hash, hrp)
@ -133,24 +143,21 @@ def address_p2wpkh(pubkey: bytes, hrp: str) -> str:
pubkeyhash = ecdsa_hash_pubkey(pubkey)
address = bech32.encode(hrp, _BECH32_WITVER, pubkeyhash)
if address is None:
raise AddressError(FailureType.ProcessError,
'Invalid address')
raise AddressError(FailureType.ProcessError, "Invalid address")
return address
def address_p2wsh(witness_script_hash: bytes, hrp: str) -> str:
address = bech32.encode(hrp, _BECH32_WITVER, witness_script_hash)
if address is None:
raise AddressError(FailureType.ProcessError,
'Invalid address')
raise AddressError(FailureType.ProcessError, "Invalid address")
return address
def decode_bech32_address(prefix: str, address: str) -> bytes:
witver, raw = bech32.decode(prefix, address)
if witver != _BECH32_WITVER:
raise AddressError(FailureType.ProcessError,
'Invalid address witness program')
raise AddressError(FailureType.ProcessError, "Invalid address witness program")
return bytes(raw)
@ -162,7 +169,7 @@ def address_to_cashaddr(address: str, coin: CoinInfo) -> str:
elif version == coin.address_type_p2sh:
version = cashaddr.ADDRESS_TYPE_P2SH
else:
raise ValueError('Unknown cashaddr address type')
raise ValueError("Unknown cashaddr address type")
return cashaddr.encode(coin.cashaddr_prefix, version, data)
@ -170,7 +177,7 @@ def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
if pubkey[0] == 0x04:
ensure(len(pubkey) == 65) # uncompressed format
elif pubkey[0] == 0x00:
ensure(len(pubkey) == 1) # point at infinity
ensure(len(pubkey) == 1) # point at infinity
else:
ensure(len(pubkey) == 33) # compresssed format
h = sha256(pubkey).digest()
@ -179,7 +186,9 @@ def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
def address_short(coin: CoinInfo, address: str) -> str:
if coin.cashaddr_prefix is not None and address.startswith(coin.cashaddr_prefix + ':'):
return address[len(coin.cashaddr_prefix) + 1:]
if coin.cashaddr_prefix is not None and address.startswith(
coin.cashaddr_prefix + ":"
):
return address[len(coin.cashaddr_prefix) + 1 :]
else:
return address

@ -20,14 +20,12 @@ from apps.common.coininfo import CoinInfo
class UiConfirmOutput:
def __init__(self, output: TxOutputType, coin: CoinInfo):
self.output = output
self.coin = coin
class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinInfo):
self.spending = spending
self.fee = fee
@ -35,14 +33,12 @@ class UiConfirmTotal:
class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinInfo):
self.fee = fee
self.coin = coin
class UiConfirmForeignAddress:
def __init__(self, address_n: list, coin: CoinInfo):
self.address_n = address_n
self.coin = coin
@ -64,7 +60,7 @@ def confirm_foreign_address(address_n: list, coin: CoinInfo):
return (yield UiConfirmForeignAddress(address_n, coin))
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes = None):
tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None
@ -73,7 +69,9 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
return sanitize_tx_meta(ack.tx)
def request_tx_extra_data(tx_req: TxRequest, offset: int, size: int, tx_hash: bytes=None):
def request_tx_extra_data(
tx_req: TxRequest, offset: int, size: int, tx_hash: bytes = None
):
tx_req.request_type = TXEXTRADATA
tx_req.details.extra_data_offset = offset
tx_req.details.extra_data_len = size
@ -84,7 +82,7 @@ def request_tx_extra_data(tx_req: TxRequest, offset: int, size: int, tx_hash: by
return ack.tx.extra_data
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes = None):
tx_req.request_type = TXINPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
@ -93,7 +91,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
return sanitize_tx_input(ack.tx)
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes = None):
tx_req.request_type = TXOUTPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
@ -121,7 +119,7 @@ def sanitize_sign_tx(tx: SignTx) -> SignTx:
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
tx.coin_name = tx.coin_name if tx.coin_name is not None else "Bitcoin"
tx.expiry = tx.expiry if tx.expiry is not None else 0
tx.overwintered = tx.overwintered if tx.overwintered is not None else False
return tx

@ -10,7 +10,7 @@ from apps.wallet.sign_tx import addresses
def format_coin_amount(amount, coin):
return '%s %s' % (format_amount(amount, 8), coin.coin_shortcut)
return "%s %s" % (format_amount(amount, 8), coin.coin_shortcut)
def split_address(address):
@ -25,39 +25,36 @@ async def confirm_output(ctx, output, coin):
if output.script_type == OutputScriptType.PAYTOOPRETURN:
data = hexlify(output.op_return_data).decode()
if len(data) >= 18 * 5:
data = data[:(18 * 5 - 3)] + '...'
text = Text('OP_RETURN', ui.ICON_SEND, icon_color=ui.GREEN)
data = data[: (18 * 5 - 3)] + "..."
text = Text("OP_RETURN", ui.ICON_SEND, icon_color=ui.GREEN)
text.mono(*split_op_return(data))
else:
address = output.address
address_short = addresses.address_short(coin, address)
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(format_coin_amount(output.amount, coin) + ' to')
text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal(format_coin_amount(output.amount, coin) + " to")
text.mono(*split_address(address_short))
return await confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_total(ctx, spending, fee, coin):
text = Text('Confirm transaction', ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('Total amount:')
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal("Total amount:")
text.bold(format_coin_amount(spending, coin))
text.normal('including fee:')
text.normal("including fee:")
text.bold(format_coin_amount(fee, coin))
return await hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
async def confirm_feeoverthreshold(ctx, fee, coin):
text = Text('High fee', ui.ICON_SEND, icon_color=ui.GREEN)
text.normal('The fee of')
text = Text("High fee", ui.ICON_SEND, icon_color=ui.GREEN)
text.normal("The fee of")
text.bold(format_coin_amount(fee, coin))
text.normal('is unexpectedly high.', 'Continue?')
text.normal("is unexpectedly high.", "Continue?")
return await confirm(ctx, text, ButtonRequestType.FeeOverThreshold)
async def confirm_foreign_address(ctx, address_n, coin):
text = Text('Confirm sending', ui.ICON_SEND, icon_color=ui.RED)
text.normal(
'Trying to spend',
'coins from another chain.',
'Continue?')
text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.RED)
text.normal("Trying to spend", "coins from another chain.", "Continue?")
return await confirm(ctx, text, ButtonRequestType.SignTx)

@ -40,12 +40,12 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
n = len(pubkeys)
if n < 1 or n > 15 or m < 1 or m > 15:
raise MultisigError(FailureType.DataError, 'Invalid multisig parameters')
raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
for hd in pubkeys:
d = hd.node
if len(d.public_key) != 33 or len(d.chain_code) != 32:
raise MultisigError(FailureType.DataError, 'Invalid multisig parameters')
raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
# casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key))
@ -68,8 +68,7 @@ def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) ->
for i, hd in enumerate(multisig.pubkeys):
if multisig_get_pubkey(hd) == pubkey:
return i
raise MultisigError(FailureType.DataError,
'Pubkey not found in multisig script')
raise MultisigError(FailureType.DataError, "Pubkey not found in multisig script")
def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
@ -80,7 +79,8 @@ def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
fingerprint=n.fingerprint,
child_num=n.child_num,
chain_code=n.chain_code,
public_key=n.public_key)
public_key=n.public_key,
)
for i in p:
node.derive(i, True)
return node.public_key()

@ -28,11 +28,10 @@ class Zip143Error(ValueError):
class Zip143:
def __init__(self):
self.h_prevouts = HashWriter(blake2b, outlen=32, personal=b'ZcashPrevoutHash')
self.h_sequence = HashWriter(blake2b, outlen=32, personal=b'ZcashSequencHash')
self.h_outputs = HashWriter(blake2b, outlen=32, personal=b'ZcashOutputsHash')
self.h_prevouts = HashWriter(blake2b, outlen=32, personal=b"ZcashPrevoutHash")
self.h_sequence = HashWriter(blake2b, outlen=32, personal=b"ZcashSequencHash")
self.h_outputs = HashWriter(blake2b, outlen=32, personal=b"ZcashOutputsHash")
def add_prevouts(self, txi: TxInputType):
write_bytes_rev(self.h_prevouts, txi.prev_hash)
@ -53,31 +52,42 @@ class Zip143:
def get_outputs_hash(self) -> bytes:
return get_tx_hash(self.h_outputs)
def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes:
h_preimage = HashWriter(blake2b, outlen=32, personal=b'ZcashSigHash\x19\x1b\xa8\x5b') # BRANCH_ID = 0x5ba81b19
def preimage_hash(
self,
coin: CoinInfo,
tx: SignTx,
txi: TxInputType,
pubkeyhash: bytes,
sighash: int,
) -> bytes:
h_preimage = HashWriter(
blake2b, outlen=32, personal=b"ZcashSigHash\x19\x1b\xa8\x5b"
) # BRANCH_ID = 0x5ba81b19
assert tx.overwintered
write_uint32(h_preimage, tx.version | OVERWINTERED) # 1. nVersion | fOverwintered
write_uint32(h_preimage, coin.version_group_id) # 2. nVersionGroupId
write_uint32(
h_preimage, tx.version | OVERWINTERED
) # 1. nVersion | fOverwintered
write_uint32(h_preimage, coin.version_group_id) # 2. nVersionGroupId
write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # 3. hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # 4. hashSequence
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # 5. hashOutputs
write_bytes(h_preimage, b'\x00' * 32) # 6. hashJoinSplits
write_uint32(h_preimage, tx.lock_time) # 7. nLockTime
write_uint32(h_preimage, tx.expiry) # 8. expiryHeight
write_uint32(h_preimage, sighash) # 9. nHashType
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # 5. hashOutputs
write_bytes(h_preimage, b"\x00" * 32) # 6. hashJoinSplits
write_uint32(h_preimage, tx.lock_time) # 7. nLockTime
write_uint32(h_preimage, tx.expiry) # 8. expiryHeight
write_uint32(h_preimage, sighash) # 9. nHashType
write_bytes_rev(h_preimage, txi.prev_hash) # 10a. outpoint
write_bytes_rev(h_preimage, txi.prev_hash) # 10a. outpoint
write_uint32(h_preimage, txi.prev_index)
script_code = self.derive_script_code(txi, pubkeyhash) # 10b. scriptCode
script_code = self.derive_script_code(txi, pubkeyhash) # 10b. scriptCode
write_varint(h_preimage, len(script_code))
write_bytes(h_preimage, script_code)
write_uint64(h_preimage, txi.amount) # 10c. value
write_uint64(h_preimage, txi.amount) # 10c. value
write_uint32(h_preimage, txi.sequence) # 10d. nSequence
write_uint32(h_preimage, txi.sequence) # 10d. nSequence
return get_tx_hash(h_preimage)
@ -86,12 +96,16 @@ class Zip143:
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
if txi.multisig:
return output_script_multisig(multisig_get_pubkeys(txi.multisig), txi.multisig.m)
return output_script_multisig(
multisig_get_pubkeys(txi.multisig), txi.multisig.m
)
p2pkh = txi.script_type == InputScriptType.SPENDADDRESS
if p2pkh:
return output_script_p2pkh(pubkeyhash)
else:
raise Zip143Error(FailureType.DataError,
'Unknown input script type for zip143 script code')
raise Zip143Error(
FailureType.DataError,
"Unknown input script type for zip143 script code",
)

@ -20,7 +20,7 @@ def advance():
def report_init():
ui.display.clear()
ui.header('Signing transaction')
ui.header("Signing transaction")
def report():

@ -51,7 +51,9 @@ def output_script_p2sh(scripthash: bytes) -> bytearray:
return s
def script_replay_protection_bip115(block_hash: bytes, block_height: int) -> bytearray:
def script_replay_protection_bip115(
block_hash: bytes, block_height: bytes
) -> bytearray:
if block_hash is None or block_height is None:
return bytearray()
assert len(block_hash) == 32

@ -24,7 +24,6 @@ class Bip143Error(ValueError):
class Bip143:
def __init__(self):
self.h_prevouts = HashWriter(sha256)
self.h_sequence = HashWriter(sha256)
@ -49,27 +48,34 @@ class Bip143:
def get_outputs_hash(self, coin: CoinInfo) -> bytes:
return get_tx_hash(self.h_outputs, double=coin.sign_hash_double)
def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes:
def preimage_hash(
self,
coin: CoinInfo,
tx: SignTx,
txi: TxInputType,
pubkeyhash: bytes,
sighash: int,
) -> bytes:
h_preimage = HashWriter(sha256)
assert not tx.overwintered
write_uint32(h_preimage, tx.version) # nVersion
write_uint32(h_preimage, tx.version) # nVersion
write_bytes(h_preimage, bytearray(self.get_prevouts_hash(coin))) # hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash(coin))) # hashSequence
write_bytes_rev(h_preimage, txi.prev_hash) # outpoint
write_uint32(h_preimage, txi.prev_index) # outpoint
write_bytes_rev(h_preimage, txi.prev_hash) # outpoint
write_uint32(h_preimage, txi.prev_index) # outpoint
script_code = self.derive_script_code(txi, pubkeyhash) # scriptCode
script_code = self.derive_script_code(txi, pubkeyhash) # scriptCode
write_varint(h_preimage, len(script_code))
write_bytes(h_preimage, script_code)
write_uint64(h_preimage, txi.amount) # amount
write_uint32(h_preimage, txi.sequence) # nSequence
write_bytes(h_preimage, bytearray(self.get_outputs_hash(coin))) # hashOutputs
write_uint32(h_preimage, tx.lock_time) # nLockTime
write_uint32(h_preimage, sighash) # nHashType
write_uint64(h_preimage, txi.amount) # amount
write_uint32(h_preimage, txi.sequence) # nSequence
write_bytes(h_preimage, bytearray(self.get_outputs_hash(coin))) # hashOutputs
write_uint32(h_preimage, tx.lock_time) # nLockTime
write_uint32(h_preimage, sighash) # nHashType
return get_tx_hash(h_preimage, double=coin.sign_hash_double)
@ -78,16 +84,22 @@ class Bip143:
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
if txi.multisig:
return output_script_multisig(multisig_get_pubkeys(txi.multisig), txi.multisig.m)
p2pkh = (txi.script_type == InputScriptType.SPENDWITNESS or
txi.script_type == InputScriptType.SPENDP2SHWITNESS or
txi.script_type == InputScriptType.SPENDADDRESS)
return output_script_multisig(
multisig_get_pubkeys(txi.multisig), txi.multisig.m
)
p2pkh = (
txi.script_type == InputScriptType.SPENDWITNESS
or txi.script_type == InputScriptType.SPENDP2SHWITNESS
or txi.script_type == InputScriptType.SPENDADDRESS
)
if p2pkh:
# for p2wpkh in p2sh or native p2wpkh
# the scriptCode is a classic p2pkh
return output_script_p2pkh(pubkeyhash)
else:
raise Bip143Error(FailureType.DataError,
'Unknown input script type for bip143 script code')
raise Bip143Error(
FailureType.DataError,
"Unknown input script type for bip143 script code",
)

@ -94,35 +94,40 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
if txi.multisig:
multifp.add(txi.multisig)
if txi.script_type in (InputScriptType.SPENDWITNESS,
InputScriptType.SPENDP2SHWITNESS):
if txi.script_type in (
InputScriptType.SPENDWITNESS,
InputScriptType.SPENDP2SHWITNESS,
):
if not coin.segwit:
raise SigningError(FailureType.DataError,
'Segwit not enabled on this coin')
raise SigningError(
FailureType.DataError, "Segwit not enabled on this coin"
)
if not txi.amount:
raise SigningError(FailureType.DataError,
'Segwit input without amount')
raise SigningError(FailureType.DataError, "Segwit input without amount")
segwit[i] = True
segwit_in += txi.amount
total_in += txi.amount
elif txi.script_type in (InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG):
elif txi.script_type in (
InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG,
):
if coin.force_bip143 or tx.overwintered:
if not txi.amount:
raise SigningError(FailureType.DataError,
'BIP/ZIP 143 input without amount')
raise SigningError(
FailureType.DataError, "BIP/ZIP 143 input without amount"
)
segwit[i] = False
segwit_in += txi.amount
total_in += txi.amount
else:
segwit[i] = False
total_in += await get_prevtx_output_value(
coin, tx_req, txi.prev_hash, txi.prev_index)
coin, tx_req, txi.prev_hash, txi.prev_index
)
else:
raise SigningError(FailureType.DataError,
'Wrong input script type')
raise SigningError(FailureType.DataError, "Wrong input script type")
for o in range(tx.outputs_count):
# STAGE_REQUEST_3_OUTPUT
@ -135,8 +140,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
# output is change and does not need confirmation
change_out = txo.amount
elif not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled,
'Output cancelled')
raise SigningError(FailureType.ActionCancelled, "Output cancelled")
write_tx_output(h_first, txo_bin)
hash143.add_output(txo_bin)
@ -144,18 +148,15 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
fee = total_in - total_out
if fee < 0:
raise SigningError(FailureType.NotEnoughFunds,
'Not enough funds')
raise SigningError(FailureType.NotEnoughFunds, "Not enough funds")
# fee > (coin.maxfee per byte * tx size)
if fee > (coin.maxfee_kb / 1000) * (weight.get_total() / 4):
if not await confirm_feeoverthreshold(fee, coin):
raise SigningError(FailureType.ActionCancelled,
'Signing cancelled')
raise SigningError(FailureType.ActionCancelled, "Signing cancelled")
if not await confirm_total(total_out - change_out, fee, coin):
raise SigningError(FailureType.ActionCancelled,
'Total cancelled')
raise SigningError(FailureType.ActionCancelled, "Total cancelled")
return h_first, hash143, segwit, total_in, wallet_path
@ -191,11 +192,14 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign)
is_segwit = (txi_sign.script_type == InputScriptType.SPENDWITNESS or
txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS)
is_segwit = (
txi_sign.script_type == InputScriptType.SPENDWITNESS
or txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS
)
if not is_segwit:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n)
@ -203,7 +207,8 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(coin, tx, True))
write_tx_input(w_txi, txi_sign)
@ -215,17 +220,21 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign = await request_tx_input(tx_req, i_sign)
input_check_wallet_path(txi_sign, wallet_path)
is_bip143 = (txi_sign.script_type == InputScriptType.SPENDADDRESS or
txi_sign.script_type == InputScriptType.SPENDMULTISIG)
is_bip143 = (
txi_sign.script_type == InputScriptType.SPENDADDRESS
or txi_sign.script_type == InputScriptType.SPENDMULTISIG
)
if not is_bip143 or txi_sign.amount > authorized_in:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
authorized_in -= txi_sign.amount
key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash(
coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)
)
# if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig:
@ -237,9 +246,11 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# serialize input with correct signature
txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature)
coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign)
@ -254,10 +265,12 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
h_second = HashWriter(sha256)
if tx.overwintered:
write_uint32(h_sign, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(h_sign, coin.version_group_id) # nVersionGroupId
write_uint32(
h_sign, tx.version | OVERWINTERED
) # nVersion | fOverwintered
write_uint32(h_sign, coin.version_group_id) # nVersionGroupId
else:
write_uint32(h_sign, tx.version) # nVersion
write_uint32(h_sign, tx.version) # nVersion
write_varint(h_sign, tx.inputs_count)
@ -274,15 +287,21 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
txi_sign.script_sig = output_script_multisig(
multisig_get_pubkeys(txi_sign.multisig),
txi_sign.multisig.m)
multisig_get_pubkeys(txi_sign.multisig), txi_sign.multisig.m
)
elif txi_sign.script_type == InputScriptType.SPENDADDRESS:
txi_sign.script_sig = output_script_p2pkh(ecdsa_hash_pubkey(key_sign_pub))
txi_sign.script_sig = output_script_p2pkh(
ecdsa_hash_pubkey(key_sign_pub)
)
if coin.bip115:
txi_sign.script_sig += script_replay_protection_bip115(txi_sign.prev_block_hash_bip115, txi_sign.prev_block_height_bip115)
txi_sign.script_sig += script_replay_protection_bip115(
txi_sign.prev_block_hash_bip115,
txi_sign.prev_block_height_bip115,
)
else:
raise SigningError(FailureType.ProcessError,
'Unknown transaction type')
raise SigningError(
FailureType.ProcessError, "Unknown transaction type"
)
else:
txi.script_sig = bytes()
write_tx_input(h_sign, txi)
@ -300,29 +319,34 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
write_uint32(h_sign, tx.lock_time)
if tx.overwintered:
write_uint32(h_sign, tx.expiry) # expiryHeight
write_varint(h_sign, 0) # nJoinSplit
write_varint(h_sign, 0) # nJoinSplit
write_uint32(h_sign, get_hash_type(coin))
# check the control digests
if get_tx_hash(h_first, False) != get_tx_hash(h_second):
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
# if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig:
multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
# compute the signature from the tx digest
signature = ecdsa_sign(key_sign, get_tx_hash(h_sign, double=coin.sign_hash_double))
signature = ecdsa_sign(
key_sign, get_tx_hash(h_sign, double=coin.sign_hash_double)
)
tx_ser.signature_index = i_sign
tx_ser.signature = signature
# serialize input with correct signature
txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature)
coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign)
@ -338,8 +362,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
# serialize output
w_txo_bin = bytearray_with_cap(
5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
w_txo_bin = bytearray_with_cap(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
if o == 0: # serializing first output => prepend outputs count
write_varint(w_txo_bin, tx.outputs_count)
write_tx_output(w_txo_bin, txo_bin)
@ -359,23 +382,29 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi = await request_tx_input(tx_req, i)
input_check_wallet_path(txi, wallet_path)
is_segwit = (txi.script_type == InputScriptType.SPENDWITNESS or
txi.script_type == InputScriptType.SPENDP2SHWITNESS)
is_segwit = (
txi.script_type == InputScriptType.SPENDWITNESS
or txi.script_type == InputScriptType.SPENDP2SHWITNESS
)
if not is_segwit or txi.amount > authorized_in:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
authorized_in -= txi.amount
key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash(
coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)
)
signature = ecdsa_sign(key_sign, hash143_hash)
if txi.multisig:
# find out place of our signature based on the pubkey
signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub)
witness = witness_p2wsh(txi.multisig, signature, signature_index, get_hash_type(coin))
witness = witness_p2wsh(
txi.multisig, signature, signature_index, get_hash_type(coin)
)
else:
witness = witness_p2wpkh(signature, key_sign_pub, get_hash_type(coin))
@ -392,12 +421,14 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
write_uint32(tx_ser.serialized_tx, tx.lock_time)
if tx.overwintered:
write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight
write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
await request_tx_finish(tx_req)
async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int:
async def get_prevtx_output_value(
coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int
) -> int:
total_out = 0 # sum of output amounts
# STAGE_REQUEST_2_PREV_META
@ -407,9 +438,9 @@ async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash:
if tx.overwintered:
write_uint32(txh, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(txh, coin.version_group_id) # nVersionGroupId
write_uint32(txh, coin.version_group_id) # nVersionGroupId
else:
write_uint32(txh, tx.version) # nVersion
write_uint32(txh, tx.version) # nVersion
write_varint(txh, tx.inputs_cnt)
@ -440,8 +471,7 @@ async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash:
ofs += len(data)
if get_tx_hash(txh, double=coin.sign_hash_double, reverse=True) != prev_hash:
raise SigningError(FailureType.ProcessError,
'Encountered invalid prev_hash')
raise SigningError(FailureType.ProcessError, "Encountered invalid prev_hash")
return total_out
@ -463,7 +493,7 @@ def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False):
w_txi = bytearray()
if tx.overwintered:
write_uint32(w_txi, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(w_txi, coin.version_group_id) # nVersionGroupId
write_uint32(w_txi, coin.version_group_id) # nVersionGroupId
else:
write_uint32(w_txi, tx.version) # nVersion
if segwit:
@ -482,33 +512,36 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
if o.script_type == OutputScriptType.PAYTOOPRETURN:
# op_return output
if o.amount != 0:
raise SigningError(FailureType.DataError,
'OP_RETURN output with non-zero amount')
raise SigningError(
FailureType.DataError, "OP_RETURN output with non-zero amount"
)
return output_script_paytoopreturn(o.op_return_data)
if o.address_n:
# change output
if o.address:
raise SigningError(FailureType.DataError, 'Address in change output')
raise SigningError(FailureType.DataError, "Address in change output")
o.address = get_address_for_change(o, coin, root)
else:
if not o.address:
raise SigningError(FailureType.DataError, 'Missing address')
raise SigningError(FailureType.DataError, "Missing address")
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh
witprog = decode_bech32_address(coin.bech32_prefix, o.address)
return output_script_native_p2wpkh_or_p2wsh(witprog)
if coin.cashaddr_prefix is not None and o.address.startswith(coin.cashaddr_prefix + ':'):
prefix, addr = o.address.split(':')
if coin.cashaddr_prefix is not None and o.address.startswith(
coin.cashaddr_prefix + ":"
):
prefix, addr = o.address.split(":")
version, data = cashaddr.decode(prefix, addr)
if version == cashaddr.ADDRESS_TYPE_P2KH:
version = coin.address_type
elif version == cashaddr.ADDRESS_TYPE_P2SH:
version = coin.address_type_p2sh
else:
raise ValueError('Unknown cashaddr address type')
raise ValueError("Unknown cashaddr address type")
raw_address = bytes([version]) + data
else:
raw_address = base58.decode_check(o.address, coin.b58_hash)
@ -518,7 +551,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
pubkeyhash = address_type.strip(coin.address_type, raw_address)
script = output_script_p2pkh(pubkeyhash)
if coin.bip115:
script += script_replay_protection_bip115(o.block_hash_bip115, o.block_height_bip115)
script += script_replay_protection_bip115(
o.block_hash_bip115, o.block_height_bip115
)
return script
elif address_type.check(coin.address_type_p2sh, raw_address):
@ -526,10 +561,12 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
script = output_script_p2sh(scripthash)
if coin.bip115:
script += script_replay_protection_bip115(o.block_hash_bip115, o.block_height_bip115)
script += script_replay_protection_bip115(
o.block_hash_bip115, o.block_height_bip115
)
return script
raise SigningError(FailureType.DataError, 'Invalid address type')
raise SigningError(FailureType.DataError, "Invalid address type")
def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode):
@ -542,33 +579,40 @@ def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode):
elif o.script_type == OutputScriptType.PAYTOP2SHWITNESS:
input_script_type = InputScriptType.SPENDP2SHWITNESS
else:
raise SigningError(FailureType.DataError, 'Invalid script type')
return get_address(input_script_type, coin, node_derive(root, o.address_n), o.multisig)
raise SigningError(FailureType.DataError, "Invalid script type")
return get_address(
input_script_type, coin, node_derive(root, o.address_n), o.multisig
)
def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool:
is_segwit = (o.script_type == OutputScriptType.PAYTOWITNESS or
o.script_type == OutputScriptType.PAYTOP2SHWITNESS)
is_segwit = (
o.script_type == OutputScriptType.PAYTOWITNESS
or o.script_type == OutputScriptType.PAYTOP2SHWITNESS
)
if is_segwit and o.amount > segwit_in:
# if the output is segwit, make sure it doesn't spend more than what the
# segwit inputs paid. this is to prevent user being tricked into
# creating ANYONECANSPEND outputs before full segwit activation.
return False
return (wallet_path is not None and
wallet_path == o.address_n[:-_BIP32_WALLET_DEPTH] and
o.address_n[-2] <= _BIP32_CHANGE_CHAIN and
o.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT)
return (
wallet_path is not None
and wallet_path == o.address_n[:-_BIP32_WALLET_DEPTH]
and o.address_n[-2] <= _BIP32_CHANGE_CHAIN
and o.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
)
# Tx Inputs
# ===
def input_derive_script(coin: CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
def input_derive_script(
coin: CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes = None
) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh
return input_script_p2pkh_or_p2sh(
pubkey, signature, get_hash_type(coin))
return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin))
if i.script_type == InputScriptType.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh
@ -591,10 +635,11 @@ def input_derive_script(coin: CoinInfo, i: TxInputType, pubkey: bytes, signature
# p2sh multisig
signature_index = multisig_pubkey_index(i.multisig, pubkey)
return input_script_multisig(
i.multisig, signature, signature_index, get_hash_type(coin))
i.multisig, signature, signature_index, get_hash_type(coin)
)
else:
raise SigningError(FailureType.ProcessError, 'Invalid script type')
raise SigningError(FailureType.ProcessError, "Invalid script type")
def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list:
@ -615,8 +660,9 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
return # there was a mismatch in Phase 1, ignore it now
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if wallet_path != address_n:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode:
@ -630,7 +676,9 @@ def address_n_matches_coin(address_n: list, coin: CoinInfo) -> bool:
return True # path is too short
if address_n[0] not in (44 | 0x80000000, 49 | 0x80000000, 84 | 0x80000000):
return True # path is not BIP44/49/84
return address_n[1] == (coin.slip44 | 0x80000000) # check whether coin_type matches slip44 value
return address_n[1] == (
coin.slip44 | 0x80000000
) # check whether coin_type matches slip44 value
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
@ -640,10 +688,8 @@ def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
def is_change(
txo: TxOutputType,
wallet_path: list,
segwit_in: int,
multifp: MultisigFingerprint) -> bool:
txo: TxOutputType, wallet_path: list, segwit_in: int, multifp: MultisigFingerprint
) -> bool:
if txo.multisig:
if not multifp.matches(txo.multisig):
return False

@ -34,14 +34,14 @@ _TXSIZE_WITNESSSCRIPT = const(34)
class TxWeightCalculator:
def __init__(self, inputs_count: int, outputs_count: int):
self.inputs_count = inputs_count
self.counter = 4 * (
_TXSIZE_HEADER +
_TXSIZE_FOOTER +
self.ser_length_size(inputs_count) +
self.ser_length_size(outputs_count))
_TXSIZE_HEADER
+ _TXSIZE_FOOTER
+ self.ser_length_size(inputs_count)
+ self.ser_length_size(outputs_count)
)
self.segwit = False
def add_witness_header(self):
@ -53,26 +53,31 @@ class TxWeightCalculator:
def add_input(self, i: TxInputType):
if i.multisig:
multisig_script_size = (
_TXSIZE_MULTISIGSCRIPT +
len(i.multisig.pubkeys) * (1 + _TXSIZE_PUBKEY))
multisig_script_size = _TXSIZE_MULTISIGSCRIPT + len(i.multisig.pubkeys) * (
1 + _TXSIZE_PUBKEY
)
input_script_size = (
1 + # the OP_FALSE bug in multisig
i.multisig.m * (1 + _TXSIZE_SIGNATURE) +
self.op_push_size(multisig_script_size) +
multisig_script_size)
1
+ i.multisig.m * (1 + _TXSIZE_SIGNATURE) # the OP_FALSE bug in multisig
+ self.op_push_size(multisig_script_size)
+ multisig_script_size
)
else:
input_script_size = 1 + _TXSIZE_SIGNATURE + 1 + _TXSIZE_PUBKEY
self.counter += 4 * _TXSIZE_INPUT
if (i.script_type == InputScriptType.SPENDADDRESS or
i.script_type == InputScriptType.SPENDMULTISIG):
if (
i.script_type == InputScriptType.SPENDADDRESS
or i.script_type == InputScriptType.SPENDMULTISIG
):
input_script_size += self.ser_length_size(input_script_size)
self.counter += 4 * input_script_size
elif (i.script_type == InputScriptType.SPENDWITNESS or
i.script_type == InputScriptType.SPENDP2SHWITNESS):
elif (
i.script_type == InputScriptType.SPENDWITNESS
or i.script_type == InputScriptType.SPENDP2SHWITNESS
):
self.add_witness_header()
if i.script_type == InputScriptType.SPENDP2SHWITNESS:
if i.multisig:

@ -130,7 +130,7 @@ def bytearray_with_cap(cap: int) -> bytearray:
# ===
def get_tx_hash(w, double: bool=False, reverse: bool=False) -> bytes:
def get_tx_hash(w, double: bool = False, reverse: bool = False) -> bytes:
d = w.get_digest()
if double:
d = sha256(d).digest()

@ -21,7 +21,7 @@ async def verify_message(ctx, msg):
message = msg.message
address = msg.address
signature = msg.signature
coin_name = msg.coin_name or 'Bitcoin'
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
digest = message_digest(coin, message)
@ -37,12 +37,12 @@ async def verify_message(ctx, msg):
script_type = SPENDWITNESS # native segwit
signature = bytes([signature[0] - 8]) + signature[1:]
else:
raise wire.ProcessError('Invalid signature')
raise wire.ProcessError("Invalid signature")
pubkey = secp256k1.verify_recover(signature, digest)
if not pubkey:
raise wire.ProcessError('Invalid signature')
raise wire.ProcessError("Invalid signature")
if script_type == SPENDADDRESS:
addr = address_pkh(pubkey, coin)
@ -53,21 +53,21 @@ async def verify_message(ctx, msg):
elif script_type == SPENDWITNESS:
addr = address_p2wpkh(pubkey, coin.bech32_prefix)
else:
raise wire.ProcessError('Invalid signature')
raise wire.ProcessError("Invalid signature")
if addr != address:
raise wire.ProcessError('Invalid signature')
raise wire.ProcessError("Invalid signature")
await require_confirm_verify_message(ctx, address_short(coin, address), message)
return Success(message='Message verified')
return Success(message="Message verified")
async def require_confirm_verify_message(ctx, address, message):
text = Text('Confirm address')
text = Text("Confirm address")
text.mono(*split_address(address))
await require_confirm(ctx, text)
text = Text('Verify message')
text = Text("Verify message")
text.normal(*split_message(message))
await require_confirm(ctx, text)

@ -8,7 +8,7 @@ async def bootscreen():
while True:
try:
if not config.has_pin():
config.unlock(pin_to_int(''), show_pin_timeout)
config.unlock(pin_to_int(""), show_pin_timeout)
return
await lockscreen()
label = None
@ -17,7 +17,7 @@ async def bootscreen():
if config.unlock(pin_to_int(pin), show_pin_timeout):
return
else:
label = 'Wrong PIN, enter again'
label = "Wrong PIN, enter again"
except: # noqa: E722
pass
@ -28,9 +28,9 @@ async def lockscreen():
label = storage.get_label()
image = storage.get_homescreen()
if not label:
label = 'My TREZOR'
label = "My TREZOR"
if not image:
image = res.load('apps/homescreen/res/bg.toif')
image = res.load("apps/homescreen/res/bg.toif")
await ui.backlight_slide(ui.BACKLIGHT_DIM)
@ -40,9 +40,11 @@ async def lockscreen():
ui.display.bar_radius(40, 100, 160, 40, ui.TITLE_GREY, ui.BG, 4)
ui.display.bar_radius(42, 102, 156, 36, ui.BG, ui.TITLE_GREY, 4)
ui.display.text_center(ui.WIDTH // 2, 128, 'Locked', ui.BOLD, ui.TITLE_GREY, ui.BG)
ui.display.text_center(ui.WIDTH // 2, 128, "Locked", ui.BOLD, ui.TITLE_GREY, ui.BG)
ui.display.text_center(ui.WIDTH // 2 + 10, 220, 'Tap to unlock', ui.BOLD, ui.TITLE_GREY, ui.BG)
ui.display.text_center(
ui.WIDTH // 2 + 10, 220, "Tap to unlock", ui.BOLD, ui.TITLE_GREY, ui.BG
)
ui.display.icon(45, 202, res.load(ui.ICON_CLICK), ui.TITLE_GREY, ui.BG)
await ui.backlight_slide(ui.BACKLIGHT_NORMAL)

@ -15,6 +15,7 @@ import apps.wallet
import apps.ethereum
import apps.lisk
import apps.nem
if __debug__:
import apps.debug
else:
@ -43,5 +44,6 @@ utils.set_mode_unprivileged()
# run main event loop and specify which screen is the default
from apps.homescreen.homescreen import homescreen
workflow.startdefault(homescreen)
loop.run()

@ -70,6 +70,7 @@ async def dump_uvarint(writer, n):
# But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative.
def sint_to_uint(sint):
res = sint << 1
if sint < 0:
@ -114,11 +115,10 @@ class MessageType:
setattr(self, kw, kwargs[kw])
def __eq__(self, rhs):
return (self.__class__ is rhs.__class__ and
self.__dict__ == rhs.__dict__)
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self):
return '<%s>' % self.__class__.__name__
return "<%s>" % self.__class__.__name__
class LimitedReader:
@ -191,7 +191,7 @@ async def load_message(reader, msg_type):
elif ftype is UnicodeType:
fvalue = bytearray(ivalue)
await reader.areadinto(fvalue)
fvalue = str(fvalue, 'utf8')
fvalue = str(fvalue, "utf8")
elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
else:
@ -247,7 +247,7 @@ async def dump_message(writer, msg):
await writer.awrite(svalue)
elif ftype is UnicodeType:
bvalue = bytes(svalue, 'utf8')
bvalue = bytes(svalue, "utf8")
await dump_uvarint(writer, len(bvalue))
await writer.awrite(bvalue)

@ -2,70 +2,70 @@ from trezorcrypto import AES
def AES_ECB_Encrypt(key: bytes) -> AES:
'''
"""
Create AES encryption context in ECB mode
'''
"""
return AES(AES.ECB | AES.Encrypt, key)
def AES_ECB_Decrypt(key: bytes) -> AES:
'''
"""
Create AES decryption context in ECB mode
'''
"""
return AES(AES.ECB | AES.Decrypt, key)
def AES_CBC_Encrypt(key: bytes, iv: bytes) -> AES:
'''
"""
Create AES encryption context in CBC mode
'''
"""
return AES(AES.CBC | AES.Encrypt, key, iv)
def AES_CBC_Decrypt(key: bytes, iv: bytes) -> AES:
'''
"""
Create AES decryption context in CBC mode
'''
"""
return AES(AES.CBC | AES.Decrypt, key, iv)
def AES_CFB_Encrypt(key: bytes, iv: bytes) -> AES:
'''
"""
Create AES encryption context in CFB mode
'''
"""
return AES(AES.CFB | AES.Encrypt, key, iv)
def AES_CFB_Decrypt(key: bytes, iv: bytes) -> AES:
'''
"""
Create AES decryption context in CFB mode
'''
"""
return AES(AES.CFB | AES.Decrypt, key, iv)
def AES_OFB_Encrypt(key: bytes, iv: bytes) -> AES:
'''
"""
Create AES encryption context in OFB mode
'''
"""
return AES(AES.OFB | AES.Encrypt, key, iv)
def AES_OFB_Decrypt(key: bytes, iv: bytes) -> AES:
'''
"""
Create AES decryption context in OFB mode
'''
"""
return AES(AES.OFB | AES.Decrypt, key, iv)
def AES_CTR_Encrypt(key: bytes) -> AES:
'''
"""
Create AES encryption context in CTR mode
'''
"""
return AES(AES.CTR | AES.Encrypt, key)
def AES_CTR_Decrypt(key: bytes) -> AES:
'''
"""
Create AES decryption context in CTR mode
'''
"""
return AES(AES.CTR | AES.Decrypt, key)

@ -5,7 +5,7 @@
from ubinascii import unhexlify
from ustruct import unpack
_b32alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'
_b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
_b32tab = [ord(c) for c in _b32alphabet]
_b32rev = dict([(ord(v), k) for k, v in enumerate(_b32alphabet)])
@ -24,27 +24,30 @@ def encode(s: bytes) -> str:
# leftover bit of c1 and tack it onto c2. Then we take the 2 leftover
# bits of c2 and tack them onto c3. The shifts and masks are intended
# to give us values of exactly 5 bits in width.
c1, c2, c3 = unpack('!HHB', s[i * 5:(i + 1) * 5])
c1, c2, c3 = unpack("!HHB", s[i * 5 : (i + 1) * 5])
c2 += (c1 & 1) << 16 # 17 bits wide
c3 += (c2 & 3) << 8 # 10 bits wide
encoded += bytes([_b32tab[c1 >> 11], # bits 1 - 5
_b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
_b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
_b32tab[c2 >> 12], # bits 16 - 20 (1 - 5)
_b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
_b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
_b32tab[c3 >> 5], # bits 31 - 35 (1 - 5)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5)
])
c3 += (c2 & 3) << 8 # 10 bits wide
encoded += bytes(
[
_b32tab[c1 >> 11], # bits 1 - 5
_b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
_b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
_b32tab[c2 >> 12], # bits 16 - 20 (1 - 5)
_b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
_b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
_b32tab[c3 >> 5], # bits 31 - 35 (1 - 5)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5)
]
)
# Adjust for any leftover partial quanta
if leftover == 1:
encoded = encoded[:-6] + b'======'
encoded = encoded[:-6] + b"======"
elif leftover == 2:
encoded = encoded[:-4] + b'===='
encoded = encoded[:-4] + b"===="
elif leftover == 3:
encoded = encoded[:-3] + b'==='
encoded = encoded[:-3] + b"==="
elif leftover == 4:
encoded = encoded[:-1] + b'='
encoded = encoded[:-1] + b"="
return bytes(encoded).decode()
@ -53,11 +56,11 @@ def decode(s: str) -> bytes:
s = s.encode()
quanta, leftover = divmod(len(s), 8)
if leftover:
raise ValueError('Incorrect padding')
raise ValueError("Incorrect padding")
# Strip off pad characters from the right. We need to count the pad
# characters because this will tell us how many null bytes to remove from
# the end of the decoded string.
padchars = s.find(b'=')
padchars = s.find(b"=")
if padchars > 0:
padchars = len(s) - padchars
s = s[:-padchars]
@ -71,17 +74,17 @@ def decode(s: str) -> bytes:
for c in s:
val = _b32rev.get(c)
if val is None:
raise ValueError('Non-base32 digit found')
raise ValueError("Non-base32 digit found")
acc += _b32rev[c] << shift
shift -= 5
if shift < 0:
parts.append(unhexlify(('%010x' % acc).encode()))
parts.append(unhexlify(("%010x" % acc).encode()))
acc = 0
shift = 35
# Process the last, partial quanta
last = unhexlify(bytes('%010x' % acc, "ascii"))
last = unhexlify(bytes("%010x" % acc, "ascii"))
if padchars == 0:
last = b'' # No characters
last = b"" # No characters
elif padchars == 1:
last = last[:-1]
elif padchars == 3:
@ -91,6 +94,6 @@ def decode(s: str) -> bytes:
elif padchars == 6:
last = last[:-4]
else:
raise ValueError('Incorrect padding')
raise ValueError("Incorrect padding")
parts.append(last)
return b''.join(parts)
return b"".join(parts)

@ -14,15 +14,15 @@
#
# 58 character alphabet used
_alphabet = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
_alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
def encode(data: bytes) -> str:
'''
"""
Convert bytes to base58 encoded string.
'''
"""
origlen = len(data)
data = data.lstrip(b'\0')
data = data.lstrip(b"\0")
newlen = len(data)
p, acc = 1, 0
@ -30,18 +30,18 @@ def encode(data: bytes) -> str:
acc += p * c
p = p << 8
result = ''
result = ""
while acc > 0:
acc, mod = divmod(acc, 58)
result += _alphabet[mod]
return ''.join((c for c in reversed(result + _alphabet[0] * (origlen - newlen))))
return "".join((c for c in reversed(result + _alphabet[0] * (origlen - newlen))))
def decode(string: str) -> bytes:
'''
"""
Convert base58 encoded string to bytes.
'''
"""
origlen = len(string)
string = string.lstrip(_alphabet[0])
newlen = len(string)
@ -61,28 +61,32 @@ def decode(string: str) -> bytes:
def sha256d_32(data: bytes) -> bytes:
from .hashlib import sha256
return sha256(sha256(data).digest()).digest()[:4]
def groestl512d_32(data: bytes) -> bytes:
from .hashlib import groestl512
return groestl512(groestl512(data).digest()).digest()[:4]
def encode_check(data: bytes, digestfunc=sha256d_32) -> str:
'''
"""
Convert bytes to base58 encoded string, append checksum.
'''
"""
return encode(data + digestfunc(data))
def decode_check(string: str, digestfunc=sha256d_32) -> bytes:
'''
"""
Convert base58 encoded string to bytes and verify checksum.
'''
"""
result = decode(string)
digestlen = len(digestfunc(b''))
digestlen = len(digestfunc(b""))
result, check = result[:-digestlen], result[-digestlen:]
if check != digestfunc(result):
raise ValueError('Invalid checksum')
raise ValueError("Invalid checksum")
return result

@ -56,22 +56,23 @@ def bech32_create_checksum(hrp, data):
def bech32_encode(hrp, data):
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data)
return hrp + '1' + ''.join([CHARSET[d] for d in combined])
return hrp + "1" + "".join([CHARSET[d] for d in combined])
def bech32_decode(bech):
"""Validate a Bech32 string, and determine HRP and data."""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech
):
return (None, None)
bech = bech.lower()
pos = bech.rfind('1')
pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None)
if not all(x in CHARSET for x in bech[pos + 1:]):
if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None)
hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos + 1:]]
data = [CHARSET.find(x) for x in bech[pos + 1 :]]
if not bech32_verify_checksum(hrp, data):
return (None, None)
return (hrp, data[:-6])

@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
CHARSET = 'qpzry9x8gf2tvdw0s3jn54khce6mua7l'
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
ADDRESS_TYPE_P2KH = 0
ADDRESS_TYPE_P2SH = 8
@ -60,7 +60,7 @@ def b32decode(inputs):
def b32encode(inputs):
out = ''
out = ""
for char_code in inputs:
out += CHARSET[char_code]
return out
@ -92,13 +92,13 @@ def encode(prefix, version, payload):
payload = bytes([version]) + payload
payload = convertbits(payload, 8, 5)
checksum = calculate_checksum(prefix, payload)
return prefix + ':' + b32encode(payload + checksum)
return prefix + ":" + b32encode(payload + checksum)
def decode(prefix, addr):
addr = addr.lower()
decoded = b32decode(addr)
if not verify_checksum(prefix, decoded):
raise ValueError('Bad cashaddr checksum')
raise ValueError("Bad cashaddr checksum")
data = bytes(convertbits(decoded, 5, 8))
return data[0], data[1:-6]

@ -1,4 +1,3 @@
def encode_length(l: int) -> bytes:
if l < 0x80:
return bytes([l])
@ -11,14 +10,14 @@ def encode_length(l: int) -> bytes:
def encode_int(i: bytes) -> bytes:
i = i.lstrip(b'\x00')
i = i.lstrip(b"\x00")
if i[0] >= 0x80:
i = b'\x00' + i
return b'\x02' + encode_length(len(i)) + i
i = b"\x00" + i
return b"\x02" + encode_length(len(i)) + i
def encode_seq(seq: tuple) -> bytes:
res = b''
res = b""
for i in seq:
res += encode_int(i)
return b'\x30' + encode_length(len(res)) + res
return b"\x30" + encode_length(len(res)) + res

@ -1,6 +1,4 @@
class Hmac:
def __init__(self, key, msg, digestmod):
self.digestmod = digestmod
self.inner = digestmod()
@ -15,15 +13,15 @@ class Hmac:
self.update(msg)
def update(self, msg: bytes) -> None:
'''
"""
Update the context with data.
'''
"""
self.inner.update(msg)
def digest(self) -> bytes:
'''
"""
Returns the digest of processed data.
'''
"""
outer = self.digestmod()
outer.update(bytes((x ^ 0x5C) for x in self.key))
outer.update(self.inner.digest())
@ -31,7 +29,7 @@ class Hmac:
def new(key, msg, digestmod) -> Hmac:
'''
"""
Creates a HMAC context object.
'''
"""
return Hmac(key, msg, digestmod)

@ -1,7 +1,6 @@
def int_to_bytes(x: int) -> bytes:
if x == 0:
return b''
return b""
r = bytearray()
while x:
r.append(x % 256)
@ -17,7 +16,7 @@ def encode_length(l: int, is_list: bool) -> bytes:
bl = int_to_bytes(l)
return bytes([len(bl) + offset + 55]) + bl
else:
raise ValueError('Input too long')
raise ValueError("Input too long")
def encode(data, include_length=True) -> bytes:
@ -31,7 +30,7 @@ def encode(data, include_length=True) -> bytes:
else:
return encode_length(len(data), is_list=False) + data
elif isinstance(data, list):
output = b''
output = b""
for item in data:
output += encode(item)
if include_length:
@ -39,7 +38,7 @@ def encode(data, include_length=True) -> bytes:
else:
return output
else:
raise TypeError('Invalid input of type ' + str(type(data)))
raise TypeError("Invalid input of type " + str(type(data)))
def field_length(length: int, first_byte: bytearray) -> int:

@ -10,11 +10,11 @@ ERROR = const(40)
CRITICAL = const(50)
_leveldict = {
DEBUG: ('DEBUG', '32'),
INFO: ('INFO', '36'),
WARNING: ('WARNING', '33'),
ERROR: ('ERROR', '31'),
CRITICAL: ('CRITICAL', '1;31'),
DEBUG: ("DEBUG", "32"),
INFO: ("INFO", "36"),
WARNING: ("WARNING", "33"),
ERROR: ("ERROR", "31"),
CRITICAL: ("CRITICAL", "1;31"),
}
level = DEBUG
@ -24,9 +24,14 @@ color = True
def _log(name, mlevel, msg, *args):
if __debug__ and mlevel >= level:
if color:
fmt = '%d \x1b[35m%s\x1b[0m \x1b[' + _leveldict[mlevel][1] + 'm%s\x1b[0m ' + msg
fmt = (
"%d \x1b[35m%s\x1b[0m \x1b["
+ _leveldict[mlevel][1]
+ "m%s\x1b[0m "
+ msg
)
else:
fmt = '%d %s %s ' + msg
fmt = "%d %s %s " + msg
print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args))
@ -47,7 +52,7 @@ def error(name, msg, *args):
def exception(name, exc):
_log(name, ERROR, 'exception:')
_log(name, ERROR, "exception:")
sys.print_exception(exc)

@ -1,11 +1,11 @@
'''
"""
Implements an event loop with cooperative multitasking and async I/O. Tasks in
the form of python coroutines (either plain generators or `async` functions) are
stepped through until completion, and can get asynchronously blocked by
`yield`ing or `await`ing a syscall.
See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `spawn`.
'''
"""
import utime
import utimeq
@ -22,16 +22,17 @@ _paused = {}
if __debug__:
# for performance stats
import array
log_delay_pos = 0
log_delay_rb_len = const(10)
log_delay_rb = array.array('i', [0] * log_delay_rb_len)
log_delay_rb = array.array("i", [0] * log_delay_rb_len)
def schedule(task, value=None, deadline=None):
'''
"""
Schedule task to be executed with `value` on given `deadline` (in
microseconds). Does not start the event loop itself, see `run`.
'''
"""
if deadline is None:
deadline = utime.ticks_us()
_queue.push(deadline, task, value)
@ -52,12 +53,12 @@ def close(task):
def run():
'''
"""
Loop forever, stepping through scheduled tasks and awaiting I/O events
inbetween. Use `schedule` first to add a coroutine to the task queue.
Tasks yield back to the scheduler on any I/O, usually by calling `await` on
a `Syscall`.
'''
"""
if __debug__:
global log_delay_pos
@ -98,7 +99,7 @@ def _step(task, value):
result = task.send(value)
except StopIteration as e:
if __debug__:
log.debug(__name__, 'finish: %s', task)
log.debug(__name__, "finish: %s", task)
except Exception as e:
if __debug__:
log.exception(__name__, e)
@ -109,16 +110,16 @@ def _step(task, value):
schedule(task)
else:
if __debug__:
log.error(__name__, 'unknown syscall: %s', result)
log.error(__name__, "unknown syscall: %s", result)
if after_step_hook:
after_step_hook()
class Syscall:
'''
"""
When tasks want to perform any I/O, or do any sort of communication with the
scheduler, they do so through instances of a class derived from `Syscall`.
'''
"""
def __iter__(self):
# support `yield from` or `await` on syscalls
@ -126,7 +127,7 @@ class Syscall:
class sleep(Syscall):
'''
"""
Pause current task and resume it after given delay. Although the delay is
given in microseconds, sub-millisecond precision is not guaranteed. Result
value is the calculated deadline.
@ -135,7 +136,7 @@ class sleep(Syscall):
>>> planned = await loop.sleep(1000 * 1000) # sleep for 1ms
>>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned))
'''
"""
def __init__(self, delay_us):
self.delay_us = delay_us
@ -146,7 +147,7 @@ class sleep(Syscall):
class wait(Syscall):
'''
"""
Pause current task, and resume only after a message on `msg_iface` is
received. Messages are received either from an USB interface, or the
touch display. Result value a tuple of message values.
@ -155,7 +156,7 @@ class wait(Syscall):
>>> hid_report, = await loop.wait(0xABCD) # await USB HID report
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event
'''
"""
def __init__(self, msg_iface):
self.msg_iface = msg_iface
@ -168,7 +169,7 @@ _NO_VALUE = ()
class signal(Syscall):
'''
"""
Pause current task, and let other running task to resume it later with a
result value or an exception.
@ -181,7 +182,7 @@ class signal(Syscall):
>>> # in task #2:
>>> signal.send('hello from task #2')
>>> # prints in the next iteration of the event loop
'''
"""
def __init__(self):
self.value = _NO_VALUE
@ -210,7 +211,7 @@ class signal(Syscall):
class spawn(Syscall):
'''
"""
Execute one or more children tasks and wait until one of them exits.
Return value of `spawn` is the return value of task that triggered the
completion. By default, `spawn` returns after the first child completes, and
@ -232,7 +233,7 @@ class spawn(Syscall):
Note: You should not directly `yield` a `spawn` instance, see logic in
`spawn.__iter__` for explanation. Always use `await`.
'''
"""
def __init__(self, *children, exit_others=True):
self.children = children
@ -281,7 +282,6 @@ class spawn(Syscall):
class put(Syscall):
def __init__(self, ch, value=None):
self.ch = ch
self.value = value
@ -295,7 +295,6 @@ class put(Syscall):
class take(Syscall):
def __init__(self, ch):
self.ch = ch
@ -308,7 +307,6 @@ class take(Syscall):
class chan:
def __init__(self, id=None):
self.id = id
self.putters = []

@ -2,18 +2,38 @@ from trezor import ui
def pin_to_int(pin: str) -> int:
return int('1' + pin)
return int("1" + pin)
def show_pin_timeout(seconds: int, progress: int):
if progress == 0:
ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT, ui.BG)
ui.display.text_center(ui.WIDTH // 2, 37, 'Verifying PIN', ui.BOLD, ui.FG, ui.BG, ui.WIDTH)
ui.display.text_center(
ui.WIDTH // 2, 37, "Verifying PIN", ui.BOLD, ui.FG, ui.BG, ui.WIDTH
)
ui.display.loader(progress, 0, ui.FG, ui.BG)
if seconds == 0:
ui.display.text_center(ui.WIDTH // 2, ui.HEIGHT - 22, 'Done', ui.BOLD, ui.FG, ui.BG, ui.WIDTH)
ui.display.text_center(
ui.WIDTH // 2, ui.HEIGHT - 22, "Done", ui.BOLD, ui.FG, ui.BG, ui.WIDTH
)
elif seconds == 1:
ui.display.text_center(ui.WIDTH // 2, ui.HEIGHT - 22, '1 second left', ui.BOLD, ui.FG, ui.BG, ui.WIDTH)
ui.display.text_center(
ui.WIDTH // 2,
ui.HEIGHT - 22,
"1 second left",
ui.BOLD,
ui.FG,
ui.BG,
ui.WIDTH,
)
else:
ui.display.text_center(ui.WIDTH // 2, ui.HEIGHT - 22, '%d seconds left' % seconds, ui.BOLD, ui.FG, ui.BG, ui.WIDTH)
ui.display.text_center(
ui.WIDTH // 2,
ui.HEIGHT - 22,
"%d seconds left" % seconds,
ui.BOLD,
ui.FG,
ui.BG,
ui.WIDTH,
)
ui.display.refresh()

@ -5,16 +5,16 @@ except ImportError:
def load(name):
'''
"""
Loads resource of a given name as bytes.
'''
"""
return resdata[name]
def gettext(message):
'''
"""
Returns localized string. This function is aliased to _.
'''
"""
return message

@ -18,13 +18,12 @@ BORDER = const(4) # border size in pixels
class Button(LazyWidget):
def __init__(self, area: tuple, content: str, style: dict = ui.BTN_KEY):
self.area = area
self.content = content
self.normal_style = style['normal'] or ui.BTN_KEY['normal']
self.active_style = style['active'] or ui.BTN_KEY['active']
self.disabled_style = style['disabled'] or ui.BTN_KEY['disabled']
self.normal_style = style["normal"] or ui.BTN_KEY["normal"]
self.active_style = style["active"] or ui.BTN_KEY["active"]
self.disabled_style = style["disabled"] or ui.BTN_KEY["disabled"]
self.state = BTN_INITIAL
def enable(self):
@ -50,28 +49,24 @@ class Button(LazyWidget):
self.render_content(s, ax, ay, aw, ah)
def render_background(self, s, ax, ay, aw, ah):
radius = s['radius']
bg_color = s['bg-color']
border_color = s['border-color']
radius = s["radius"]
bg_color = s["bg-color"]
border_color = s["border-color"]
if border_color != bg_color:
# render border and background on top of it
display.bar_radius(ax, ay,
aw, ah,
border_color,
ui.BG,
radius)
display.bar_radius(ax + BORDER, ay + BORDER,
aw - BORDER * 2, ah - BORDER * 2,
bg_color,
border_color,
radius)
display.bar_radius(ax, ay, aw, ah, border_color, ui.BG, radius)
display.bar_radius(
ax + BORDER,
ay + BORDER,
aw - BORDER * 2,
ah - BORDER * 2,
bg_color,
border_color,
radius,
)
else:
# render only the background
display.bar_radius(ax, ay,
aw, ah,
bg_color,
ui.BG,
radius)
display.bar_radius(ax, ay, aw, ah, bg_color, ui.BG, radius)
def render_content(self, s, ax, ay, aw, ah):
c = self.content
@ -79,10 +74,10 @@ class Button(LazyWidget):
ty = ay + ah // 2 + 8
if isinstance(c, str):
display.text_center(
tx, ty, c, s['text-style'], s['fg-color'], s['bg-color'])
tx, ty, c, s["text-style"], s["fg-color"], s["bg-color"]
)
else:
display.icon(
tx - ICON // 2, ty - ICON, c, s['fg-color'], s['bg-color'])
display.icon(tx - ICON // 2, ty - ICON, c, s["fg-color"], s["bg-color"])
def touch(self, event, pos):
pos = rotate(pos)

@ -15,21 +15,20 @@ DEFAULT_CANCEL = res.load(ui.ICON_CANCEL)
class ConfirmDialog(Widget):
def __init__(self,
content,
confirm=DEFAULT_CONFIRM,
cancel=DEFAULT_CANCEL,
confirm_style=ui.BTN_CONFIRM,
cancel_style=ui.BTN_CANCEL):
def __init__(
self,
content,
confirm=DEFAULT_CONFIRM,
cancel=DEFAULT_CANCEL,
confirm_style=ui.BTN_CONFIRM,
cancel_style=ui.BTN_CANCEL,
):
self.content = content
if cancel is not None:
self.confirm = Button(
ui.grid(9, n_x=2), confirm, style=confirm_style)
self.cancel = Button(
ui.grid(8, n_x=2), cancel, style=cancel_style)
self.confirm = Button(ui.grid(9, n_x=2), confirm, style=confirm_style)
self.cancel = Button(ui.grid(8, n_x=2), cancel, style=cancel_style)
else:
self.confirm = Button(
ui.grid(4, n_x=1), confirm, style=confirm_style)
self.confirm = Button(ui.grid(4, n_x=1), confirm, style=confirm_style)
self.cancel = None
def render(self):
@ -56,12 +55,13 @@ _STOPPED = const(-2)
class HoldToConfirmDialog(Widget):
def __init__(self,
content,
hold='Hold to confirm',
button_style=ui.BTN_CONFIRM,
loader_style=ui.LDR_DEFAULT):
def __init__(
self,
content,
hold="Hold to confirm",
button_style=ui.BTN_CONFIRM,
loader_style=ui.LDR_DEFAULT,
):
self.content = content
self.button = Button(ui.grid(4, n_x=1), hold, style=button_style)
self.loader = Loader(style=loader_style)

@ -2,7 +2,6 @@ from trezor.ui import Widget
class Container(Widget):
def __init__(self, *children):
self.children = children

@ -9,11 +9,10 @@ HOST = const(1)
class EntrySelector(Widget):
def __init__(self, content):
self.content = content
self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), 'Device')
self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), 'Host')
self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device")
self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), "Host")
def render(self):
self.device.render()

@ -8,11 +8,10 @@ _SHRINK_BY = const(2)
class Loader(ui.Widget):
def __init__(self, style=ui.LDR_DEFAULT):
self.target_ms = _TARGET_MS
self.normal_style = style['normal'] or ui.LDR_DEFAULT['normal']
self.active_style = style['active'] or ui.LDR_DEFAULT['active']
self.normal_style = style["normal"] or ui.LDR_DEFAULT["normal"]
self.active_style = style["active"] or ui.LDR_DEFAULT["active"]
self.start_ms = None
self.stop_ms = None
@ -47,15 +46,19 @@ class Loader(ui.Widget):
s = self.active_style
else:
s = self.normal_style
if s['icon'] is None:
ui.display.loader(
r, -24, s['fg-color'], s['bg-color'])
elif s['icon-fg-color'] is None:
ui.display.loader(
r, -24, s['fg-color'], s['bg-color'], res.load(s['icon']))
if s["icon"] is None:
ui.display.loader(r, -24, s["fg-color"], s["bg-color"])
elif s["icon-fg-color"] is None:
ui.display.loader(r, -24, s["fg-color"], s["bg-color"], res.load(s["icon"]))
else:
ui.display.loader(
r, -24, s['fg-color'], s['bg-color'], res.load(s['icon']), s['icon-fg-color'])
r,
-24,
s["fg-color"],
s["bg-color"],
res.load(s["icon"]),
s["icon-fg-color"],
)
def __iter__(self):
sleep = loop.sleep(1000000 // 30) # 30 fps

@ -6,7 +6,7 @@ from trezor.ui.button import BTN_CLICKED, ICON, Button
if __debug__:
from apps.debug import input_signal
MNEMONIC_KEYS = ('abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx', 'yz')
MNEMONIC_KEYS = ("abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz")
def key_buttons(keys):
@ -24,7 +24,7 @@ def compute_mask(text: str) -> int:
class Input(Button):
def __init__(self, area: tuple, content: str='', word: str=''):
def __init__(self, area: tuple, content: str = "", word: str = ""):
super().__init__(area, content)
self.word = word
self.icon = None
@ -37,26 +37,26 @@ class Input(Button):
self.taint()
if content == word: # confirm button
self.enable()
self.normal_style = ui.BTN_KEY_CONFIRM['normal']
self.active_style = ui.BTN_KEY_CONFIRM['active']
self.normal_style = ui.BTN_KEY_CONFIRM["normal"]
self.active_style = ui.BTN_KEY_CONFIRM["active"]
self.icon = ui.ICON_CONFIRM
elif word: # auto-complete button
self.enable()
self.normal_style = ui.BTN_KEY['normal']
self.active_style = ui.BTN_KEY['active']
self.normal_style = ui.BTN_KEY["normal"]
self.active_style = ui.BTN_KEY["active"]
self.icon = ui.ICON_CLICK
else: # disabled button
self.disable()
self.icon = None
def render_content(self, s, ax, ay, aw, ah):
text_style = s['text-style']
fg_color = s['fg-color']
bg_color = s['bg-color']
text_style = s["text-style"]
fg_color = s["fg-color"]
bg_color = s["bg-color"]
p = self.pending # should we draw the pending marker?
t = self.content # input content
w = self.word[len(t):] # suggested word
w = self.word[len(t) :] # suggested word
i = self.icon # rendered icon
tx = ax + 24 # x-offset of the content
@ -79,12 +79,12 @@ class Input(Button):
class MnemonicKeyboard(ui.Widget):
def __init__(self, prompt: str=''):
def __init__(self, prompt: str = ""):
self.prompt = prompt
self.input = Input(ui.grid(1, n_x=4, n_y=4, cells_x=3), '', '')
self.back = Button(ui.grid(0, n_x=4, n_y=4),
res.load(ui.ICON_BACK),
style=ui.BTN_CLEAR)
self.input = Input(ui.grid(1, n_x=4, n_y=4, cells_x=3), "", "")
self.back = Button(
ui.grid(0, n_x=4, n_y=4), res.load(ui.ICON_BACK), style=ui.BTN_CLEAR
)
self.keys = key_buttons(MNEMONIC_KEYS)
self.pbutton = None # pending key button
self.pindex = 0 # index of current pending char in pbutton
@ -114,7 +114,7 @@ class MnemonicKeyboard(ui.Widget):
if self.input.touch(event, pos) == BTN_CLICKED:
# input press, either auto-complete or confirm
if word and content == word:
self.edit('')
self.edit("")
return content
else:
self.edit(word)
@ -133,7 +133,7 @@ class MnemonicKeyboard(ui.Widget):
return
def edit(self, content, button=None, index=0):
word = bip39.find_word(content) or ''
word = bip39.find_word(content) or ""
mask = bip39.complete_word(content)
self.pbutton = button

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save