Skip to content

Commit

Permalink
feat: cooperative submarine claims (#84)
Browse files Browse the repository at this point in the history
* cooperative submarine claim

* test: make `checkTxOutAddress` work without blinding key

* test: ignore initial `swap.created` status

* fix: log message
  • Loading branch information
jackstar12 authored Feb 13, 2024
1 parent db4bd87 commit ac98154
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 129 deletions.
47 changes: 44 additions & 3 deletions boltz/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ type RefundSwapRequest struct {
Index int `json:"index"`
}

type SwapClaimDetails struct {
PubNonce HexString `json:"pubNonce"`
TransactionHash HexString `json:"transactionHash"`
Preimage HexString `json:"preimage"`
PublicKey HexString `json:"publicKey"`

Error string `json:"error"`
}

type GetInvoiceAmountResponse struct {
InvoiceAmount uint64 `json:"invoiceAmount"`
Error string `json:"error"`
Expand Down Expand Up @@ -237,9 +246,13 @@ type ClaimReverseSwapRequest struct {
}

type PartialSignature struct {
PubNonce string `json:"pubNonce"`
PartialSignature string `json:"partialSignature"`
PubNonce HexString `json:"pubNonce"`
PartialSignature HexString `json:"partialSignature"`

Error string `json:"error"`
}

type ErrorMessage struct {
Error string `json:"error"`
}

Expand Down Expand Up @@ -373,6 +386,31 @@ func (boltz *Boltz) GetInvoiceAmount(swapId string) (*GetInvoiceAmountResponse,
return &response, err
}

func (boltz *Boltz) GetSwapClaimDetails(swapId string) (*SwapClaimDetails, error) {
if boltz.DisablePartialSignatures {
return nil, errors.New("partial signatures are disabled")
}
var response SwapClaimDetails
err := boltz.sendGetRequest(fmt.Sprintf("/v2/swap/submarine/%s/claim", swapId), &response)

if response.Error != "" {
return nil, Error(errors.New(response.Error))
}

return &response, err
}

func (boltz *Boltz) SendSwapClaimSignature(swapId string, signature *PartialSignature) error {
var response ErrorMessage
err := boltz.sendPostRequest(fmt.Sprintf("/v2/swap/submarine/%s/claim", swapId), signature, &response)

if response.Error != "" {
return Error(errors.New(response.Error))
}

return err
}

func (boltz *Boltz) SetInvoice(swapId string, invoice string) (*SetInvoiceResponse, error) {
var response SetInvoiceResponse
err := boltz.sendPostRequest(fmt.Sprintf("/v2/swap/submarine/%s/invoice", swapId), SetInvoiceRequest{Invoice: invoice}, &response)
Expand Down Expand Up @@ -432,7 +470,10 @@ func (boltz *Boltz) sendPostRequest(endpoint string, requestBody interface{}, re
return err
}

return unmarshalJson(res.Body, &response)
if err := unmarshalJson(res.Body, &response); err != nil {
return fmt.Errorf("could not parse boltz response with status %d: %v", res.StatusCode, err)
}
return nil
}

func unmarshalJson(body io.ReadCloser, response interface{}) error {
Expand Down
5 changes: 2 additions & 3 deletions boltz/btc.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func getPrevoutFetcher(tx *wire.MsgTx, outputs []OutputDetails) txscript.PrevOut
return txscript.NewMultiPrevOutFetcher(previous)
}

func btcTaprootHash(transaction Transaction, outputs []OutputDetails, index int) ([32]byte, error) {
func btcTaprootHash(transaction Transaction, outputs []OutputDetails, index int) ([]byte, error) {
tx := transaction.(*BtcTransaction).MsgTx()

previous := make(map[wire.OutPoint]*wire.TxOut)
Expand All @@ -90,14 +90,13 @@ func btcTaprootHash(transaction Transaction, outputs []OutputDetails, index int)
prevoutFetcher := getPrevoutFetcher(tx, outputs)
sigHashes := txscript.NewTxSigHashes(tx, prevoutFetcher)

hash, err := txscript.CalcTaprootSignatureHash(
return txscript.CalcTaprootSignatureHash(
sigHashes,
sigHashType,
tx,
index,
prevoutFetcher,
)
return [32]byte(hash), err
}

func constructBtcTransaction(network *Network, outputs []OutputDetails, outputAddressRaw string, fee uint64) (Transaction, error) {
Expand Down
1 change: 0 additions & 1 deletion boltz/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ var swapUpdateEventStrings = map[string]SwapUpdateEvent{
var CompletedStatus = []string{
InvoiceSettled.String(),
TransactionClaimed.String(),
TransactionClaimPending.String(),
}

var FailedStatus = []string{
Expand Down
8 changes: 6 additions & 2 deletions boltz/liquid.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func NewLiquidTxFromHex(hexString string, ourOutputBlindingKey *btcec.PrivateKey
}

func (transaction *LiquidTransaction) FindVout(network *Network, addressToFind string) (uint32, uint64, error) {
if transaction.OurOutputBlindingKey == nil {
return 0, 0, errors.New("No blinding key set")
}
info, err := address.FromConfidential(addressToFind)
if err != nil {
return 0, 0, err
Expand All @@ -72,7 +75,7 @@ func (transaction *LiquidTransaction) VSize() uint64 {
return uint64(transaction.SerializeSize(false, true)) + uint64(math.Ceil(float64(witnessSize)/4))
}

func liquidTaprootHash(transaction *liquidtx.Transaction, network *Network, outputs []OutputDetails, index int, cooperative bool) [32]byte {
func liquidTaprootHash(transaction *liquidtx.Transaction, network *Network, outputs []OutputDetails, index int, cooperative bool) []byte {
var leafHash *chainhash.Hash
if !cooperative {
output := outputs[index]
Expand All @@ -91,7 +94,7 @@ func liquidTaprootHash(transaction *liquidtx.Transaction, network *Network, outp
values = append(values, out.Value)
}

return transaction.HashForWitnessV1(
hash := transaction.HashForWitnessV1(
index,
scripts,
assets,
Expand All @@ -101,6 +104,7 @@ func liquidTaprootHash(transaction *liquidtx.Transaction, network *Network, outp
leafHash,
nil,
)
return hash[:]
}

func constructLiquidTransaction(network *Network, outputs []OutputDetails, outputAddressRaw string, fee uint64) (Transaction, error) {
Expand Down
47 changes: 30 additions & 17 deletions boltz/musig.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package boltz

import (
"encoding/hex"
"bytes"
"errors"
"fmt"

Expand Down Expand Up @@ -37,30 +37,43 @@ func NewSigningSession(outputs []OutputDetails, idx int) (*MusigSession, error)
return &MusigSession{session, outputs, idx}, nil
}

func (session *MusigSession) Finalize(transaction Transaction, network *Network, boltzSignature *PartialSignature) error {
partialSignature, err := hex.DecodeString(boltzSignature.PartialSignature)
if err != nil {
return err
func (session *MusigSession) Sign(hash []byte, boltzNonce []byte) (*PartialSignature, error) {
if len(hash) != 32 {
return nil, fmt.Errorf("invalid hash length %d", len(hash))
}

nonce, err := hex.DecodeString(boltzSignature.PubNonce)
if err != nil {
return err
if len(boltzNonce) != 66 {
return nil, fmt.Errorf("invalid nonce lenth %d", len(boltzNonce))
}

if len(nonce) != 66 {
return errors.New("invalid nonce length")
all, err := session.RegisterPubNonce([66]byte(boltzNonce))
if err != nil {
return nil, err
}
if !all {
return nil, errors.New("could not combine nonces")
}

all, err := session.RegisterPubNonce([66]byte(nonce))
ourNonce := session.PublicNonce()

partial, err := session.Session.Sign([32]byte(hash))
if err != nil {
return err
return nil, err
}
if !all {
return errors.New("could not combine nonces")

b := bytes.NewBuffer(nil)
if err := partial.Encode(b); err != nil {
return nil, err
}

var hash [32]byte
return &PartialSignature{
PubNonce: HexString(ourNonce[:]),
PartialSignature: HexString(b.Bytes()),
}, nil
}

func (session *MusigSession) Finalize(transaction Transaction, network *Network, boltzSignature *PartialSignature) (err error) {
var hash []byte
isLiquid := session.outputs[session.idx].SwapTree.isLiquid
if isLiquid {
hash = liquidTaprootHash(&transaction.(*LiquidTransaction).Transaction, network, session.outputs, session.idx, true)
Expand All @@ -71,13 +84,13 @@ func (session *MusigSession) Finalize(transaction Transaction, network *Network,
return err
}

_, err = session.Sign(hash)
_, err = session.Sign(hash, boltzSignature.PubNonce)
if err != nil {
return err
}

s := &secp256k1.ModNScalar{}
s.SetByteSlice(partialSignature)
s.SetByteSlice(boltzSignature.PartialSignature)
partial := musig2.NewPartialSignature(s, nil)
haveFinal, err := session.CombineSig(&partial)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions boltz/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func (output *OutputDetails) IsRefund() bool {
return len(output.Preimage) == 0
}

func NewTxFromHex(hexString string, ourOutputBlindingKey *btcec.PrivateKey) (Transaction, error) {
if ourOutputBlindingKey != nil {
func NewTxFromHex(currency Currency, hexString string, ourOutputBlindingKey *btcec.PrivateKey) (Transaction, error) {
if currency == CurrencyLiquid {
liquidTx, err := NewLiquidTxFromHex(hexString, ourOutputBlindingKey)
if err == nil {
return liquidTx, nil
Expand Down
84 changes: 55 additions & 29 deletions cmd/boltzd/boltzd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package main

import (
"bytes"
"context"
"fmt"
"net"
Expand Down Expand Up @@ -159,6 +160,11 @@ func swapStream(t *testing.T, client client.Boltz, swapId string) nextFunc {
select {
case status, ok := <-updates:
if ok {
// ignore initial swap.created message
created := boltz.SwapCreated.String()
if status.Swap.GetStatus() == created || status.ReverseSwap.GetStatus() == created {
continue
}
if state == status.Swap.GetState() || state == status.ReverseSwap.GetState() {
return status
}
Expand Down Expand Up @@ -208,32 +214,52 @@ func TestGetInfo(t *testing.T) {
}
}

func checkTxOutAddress(t *testing.T, chain onchain.Onchain, pair boltz.Currency, txId string, outAddress string) {
currency, err := chain.GetCurrency(pair)
require.NoError(t, err)
txHex, err := currency.Tx.GetTxHex(txId)
func checkTxOutAddress(t *testing.T, chain onchain.Onchain, currency boltz.Currency, txId string, outAddress string, cooperative bool) {
transaction, err := chain.GetTransaction(currency, txId, nil)
require.NoError(t, err)

if pair == boltz.CurrencyBtc {
tx, err := boltz.NewBtcTxFromHex(txHex)
require.NoError(t, err)
if tx, ok := transaction.(*boltz.BtcTransaction); ok {

decoded, err := btcutil.DecodeAddress(outAddress, &chaincfg.RegressionNetParams)
require.NoError(t, err)
script, err := txscript.PayToAddrScript(decoded)
require.NoError(t, err)
require.Equal(t, tx.MsgTx().TxOut[0].PkScript, script)
} else if pair == boltz.CurrencyLiquid {
tx, err := boltz.NewLiquidTxFromHex(txHex, nil)
require.NoError(t, err)
for _, input := range tx.MsgTx().TxIn {
if cooperative {
require.Len(t, input.Witness, 1)
} else {
require.Greater(t, len(input.Witness), 1)
}
}

script, err := address.ToOutputScript(outAddress)
require.NoError(t, err)
for _, output := range tx.Outputs {
if len(output.Script) == 0 {
continue
if outAddress != "" {
decoded, err := btcutil.DecodeAddress(outAddress, &chaincfg.RegressionNetParams)
require.NoError(t, err)
script, err := txscript.PayToAddrScript(decoded)
require.NoError(t, err)
for _, output := range tx.MsgTx().TxOut {
if bytes.Equal(output.PkScript, script) {
return
}
}
require.Equal(t, output.Script, script)
require.Fail(t, "could not find output address in transaction")
}
} else if tx, ok := transaction.(*boltz.LiquidTransaction); ok {
for _, input := range tx.Inputs {
if cooperative {
require.Len(t, input.Witness, 1)
} else {
require.Greater(t, len(input.Witness), 1)
}
}
if outAddress != "" {
script, err := address.ToOutputScript(outAddress)
require.NoError(t, err)
for _, output := range tx.Outputs {
if len(output.Script) == 0 {
continue
}
if bytes.Equal(output.Script, script) {
return
}
}
require.Fail(t, "could not find output address in transaction")
}
}
}
Expand Down Expand Up @@ -365,12 +391,10 @@ func TestSwap(t *testing.T) {
})
require.NoError(t, err)

next := swapStream(t, client, swap.Id)
next(boltzrpc.SwapState_PENDING)

test.SendToAddress(tc.cli, swap.Address, 100000)
test.MineBlock()

next := swapStream(t, client, swap.Id)
info := next(boltzrpc.SwapState_SUCCESSFUL)
checkSwap(t, info.Swap)
})
Expand Down Expand Up @@ -439,7 +463,7 @@ func TestSwap(t *testing.T) {

require.Equal(t, int(refundFee), int(*infos[1].OnchainFee)+int(*infos[2].OnchainFee))

checkTxOutAddress(t, chain, from, infos[0].RefundTransactionId, refundAddress)
checkTxOutAddress(t, chain, from, infos[0].RefundTransactionId, refundAddress, false)

refundFee, err = chain.GetTransactionFee(from, infos[0].RefundTransactionId)
require.NoError(t, err)
Expand Down Expand Up @@ -468,6 +492,8 @@ func TestSwap(t *testing.T) {
refundFee, err := chain.GetTransactionFee(from, info.RefundTransactionId)
require.NoError(t, err)
require.Equal(t, int(refundFee), int(*info.OnchainFee))

checkTxOutAddress(t, chain, from, info.RefundTransactionId, "", true)
})

t.Run("AddressRequired", func(t *testing.T) {
Expand Down Expand Up @@ -499,10 +525,10 @@ func TestReverseSwap(t *testing.T) {
recover bool
disablePartials bool
}{
{desc: "BTC/Normal", to: boltzrpc.Currency_Btc},
{desc: "BTC/Normal", to: boltzrpc.Currency_Btc, disablePartials: true},
{desc: "BTC/ZeroConf", to: boltzrpc.Currency_Btc, zeroConf: true, external: true},
{desc: "BTC/Recover", to: boltzrpc.Currency_Btc, zeroConf: true, recover: true},
{desc: "Liquid/Normal", to: boltzrpc.Currency_Liquid},
{desc: "Liquid/Normal", to: boltzrpc.Currency_Liquid, disablePartials: true},
{desc: "Liquid/ZeroConf", to: boltzrpc.Currency_Liquid, zeroConf: true, external: true},
{desc: "Liquid/Recover", to: boltzrpc.Currency_Liquid, zeroConf: true, recover: true},
}
Expand All @@ -513,7 +539,7 @@ func TestReverseSwap(t *testing.T) {
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
cfg := loadConfig(t)
cfg.Boltz.DisablePartialSignatures = true
cfg.Boltz.DisablePartialSignatures = tc.disablePartials
client, _, stop := setup(t, cfg, "")
cfg.Node = node
chain := onchain.Onchain{
Expand Down Expand Up @@ -592,7 +618,7 @@ func TestReverseSwap(t *testing.T) {
if tc.external {
require.Equal(t, addr, info.ReverseSwap.ClaimAddress)
}
checkTxOutAddress(t, chain, currency, info.ReverseSwap.ClaimTransactionId, info.ReverseSwap.ClaimAddress)
checkTxOutAddress(t, chain, currency, info.ReverseSwap.ClaimTransactionId, info.ReverseSwap.ClaimAddress, !tc.disablePartials)

stop()
})
Expand Down
Loading

0 comments on commit ac98154

Please sign in to comment.