Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved cancellation #186

Merged
merged 3 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup Label="Language">
<LangVersion>latest</LangVersion>
<LangVersion>preview</LangVersion>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ public sealed class NearShareSender(ConnectedDevicesPlatform platform)

async Task<SenderStateMachine> PrepareTransferInternalAsync(EndpointInfo endpoint, CancellationToken cancellationToken)
{
var session = await Platform.ConnectAsync(endpoint, options: new() { TransportUpgraded = TransportUpgraded });
var session = await Platform.ConnectAsync(endpoint, options: new() { TransportUpgraded = TransportUpgraded }, cancellationToken);

Guid operationId = Guid.NewGuid();

HandshakeHandler handshake = new(Platform);
using var handShakeChannel = await session.StartClientChannelAsync(NearShareHandshakeApp.Id, NearShareHandshakeApp.Name, handshake, cancellationToken);
using var handShakeChannel = await session.StartClientChannelAsync(handshake, cancellationToken);
var handshakeResultMsg = await handshake.Execute(operationId);

// ToDo: CorrelationVector
Expand All @@ -49,8 +49,11 @@ public async Task SendFilesAsync(CdpDevice device, IReadOnlyList<CdpFileProvider
await senderStateMachine.SendFilesAsync(files, progress, cancellationToken);
}

sealed class HandshakeHandler(ConnectedDevicesPlatform cdp) : CdpAppBase(cdp)
sealed class HandshakeHandler(ConnectedDevicesPlatform cdp) : CdpAppBase(cdp), ICdpAppId
{
public static string Id { get; } = NearShareHandshakeApp.Id;
public static string Name { get; } = NearShareHandshakeApp.Name;

readonly TaskCompletionSource<CdpMessage> _promise = new();

public Task<CdpMessage> Execute(Guid operationId)
Expand Down
9 changes: 6 additions & 3 deletions lib/ShortDev.Microsoft.ConnectedDevices/CdpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
public ConnectedDevicesPlatform Platform { get; }
public SessionId SessionId { get; private set; }

public PeerCapabilities HostCapabilities { get; internal set; } = 0;

Check warning on line 22 in lib/ShortDev.Microsoft.ConnectedDevices/CdpSession.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Member 'HostCapabilities' is explicitly initialized to its default value (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1805)
public PeerCapabilities ClientCapabilities { get; internal set; } = 0;

Check warning on line 23 in lib/ShortDev.Microsoft.ConnectedDevices/CdpSession.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Member 'ClientCapabilities' is explicitly initialized to its default value (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1805)

public CdpDeviceInfo? DeviceInfo => _connectHandler.DeviceInfo;

Expand Down Expand Up @@ -72,7 +72,7 @@
), out _);
}

internal static async Task<CdpSession> ConnectClientAsync(ConnectedDevicesPlatform platform, CdpSocket socket, ConnectOptions? options = null)
internal static async Task<CdpSession> ConnectClientAsync(ConnectedDevicesPlatform platform, CdpSocket socket, ConnectOptions? options = null, CancellationToken cancellationToken = default)
{
var session = _sessionRegistry.Create(localSessionId => new(
platform,
Expand All @@ -85,7 +85,7 @@
if (options is not null)
connectHandler.UpgradeHandler.Upgraded += options.TransportUpgraded;

await connectHandler.ConnectAsync(socket);
await connectHandler.ConnectAsync(socket, cancellationToken: cancellationToken);

return session;
}
Expand Down Expand Up @@ -178,12 +178,15 @@
}
#endregion

public Task<CdpChannel> StartClientChannelAsync<TApp>(TApp handler, CancellationToken cancellationToken = default) where TApp : CdpAppBase, ICdpAppId
=> StartClientChannelAsync(TApp.Id, TApp.Name, handler, cancellationToken);

public async Task<CdpChannel> StartClientChannelAsync(string appId, string appName, CdpAppBase handler, CancellationToken cancellationToken = default)
{
if (_channelHandler is not ClientChannelHandler clientChannelHandler)
throw new InvalidOperationException("Session is not a client");

var socket = await Platform.CreateSocketAsync(_connectHandler.UpgradeHandler.RemoteEndpoint);
var socket = await Platform.CreateSocketAsync(_connectHandler.UpgradeHandler.RemoteEndpoint, cancellationToken);
return await clientChannelHandler.CreateChannelAsync(appId, appName, handler, socket, cancellationToken);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ await Task.WhenAll(_transportMap.Values
}
}

public async Task<CdpSession> ConnectAsync([NotNull] EndpointInfo endpoint, ConnectOptions? options = null)
public async Task<CdpSession> ConnectAsync([NotNull] EndpointInfo endpoint, ConnectOptions? options = null, CancellationToken cancellationToken = default)
{
var socket = await CreateSocketAsync(endpoint).ConfigureAwait(false);
return await CdpSession.ConnectClientAsync(this, socket, options).ConfigureAwait(false);
var socket = await CreateSocketAsync(endpoint, cancellationToken).ConfigureAwait(false);
return await CdpSession.ConnectClientAsync(this, socket, options, cancellationToken).ConfigureAwait(false);
}

internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint)
internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint, CancellationToken cancellationToken = default)
{
if (TryGetKnownSocket(endpoint, out var knownSocket))
return knownSocket;

var transport = TryGetTransport(endpoint.TransportType) ?? throw new InvalidOperationException($"No single transport found for type {endpoint.TransportType}");
var socket = await transport.ConnectAsync(endpoint).ConfigureAwait(false);
var socket = await transport.ConnectAsync(endpoint, cancellationToken).ConfigureAwait(false);
ReceiveLoop(socket);
return socket;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@
using ShortDev.Microsoft.ConnectedDevices.Messages.Connection.DeviceInfo;
using ShortDev.Microsoft.ConnectedDevices.Session.Upgrade;
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System.Runtime.CompilerServices;

namespace ShortDev.Microsoft.ConnectedDevices.Session.Connection;
internal sealed class ClientConnectHandler(CdpSession session, ClientUpgradeHandler upgradeHandler) : ConnectHandler(session, upgradeHandler)
{
readonly ClientUpgradeHandler _clientUpgradeHandler = upgradeHandler;
readonly ILogger _logger = session.Platform.CreateLogger<ClientConnectHandler>();

TaskCompletionSource? _promise;
public Task ConnectAsync(CdpSocket socket, bool upgradeSupported = true)
ConnectionTask? _promise;
public async Task ConnectAsync(CdpSocket socket, bool upgradeSupported = true, CancellationToken cancellationToken = default)
{
if (_promise != null)
throw new InvalidOperationException("Already connecting");
cancellationToken.ThrowIfCancellationRequested();

_promise = new();
if (Interlocked.CompareExchange(ref _promise, new(cancellationToken), null) is not null)
throw new InvalidOperationException("Already connecting");

CommonHeader header = new()
{
Expand Down Expand Up @@ -52,11 +53,14 @@ public Task ConnectAsync(CdpSocket socket, bool upgradeSupported = true)

_session.SendMessage(socket, header, writer);

return _promise.Task;
await _promise;
}

protected override void HandleMessageInternal(CdpSocket socket, CommonHeader header, ConnectionHeader connectionHeader, ref EndianReader reader)
{
if (_promise?.CancellationToken.IsCancellationRequested == true)
return;

if (connectionHeader.MessageType == ConnectionType.ConnectResponse)
{
if (_session.Cryptor != null)
Expand Down Expand Up @@ -135,7 +139,7 @@ async void PrepareSession(CdpSocket socket)
try
{
var oldSocket = socket;
socket = await _clientUpgradeHandler.RequestUpgradeAsync(oldSocket);
socket = await _clientUpgradeHandler.UpgradeAsync(oldSocket);
oldSocket.Dispose();
}
catch (Exception ex)
Expand All @@ -146,26 +150,21 @@ async void PrepareSession(CdpSocket socket)

try
{
SendAuthDone(socket);
EndianWriter writer = new(Endianness.BigEndian);
new ConnectionHeader()
{
ConnectionMode = ConnectionMode.Proximal,
MessageType = ConnectionType.AuthDoneRequest
}.Write(writer);

header.Flags = 0;
_session.SendMessage(socket, header, writer);
}
catch (Exception ex)
{
_promise?.TrySetException(ex);
}
}

void SendAuthDone(CdpSocket socket)
{
EndianWriter writer = new(Endianness.BigEndian);
new ConnectionHeader()
{
ConnectionMode = ConnectionMode.Proximal,
MessageType = ConnectionType.AuthDoneRequest
}.Write(writer);

header.Flags = 0;
_session.SendMessage(socket, header, writer);
}
}

void HandleDeviceAuthResponse(CdpSocket socket, CommonHeader header)
Expand Down Expand Up @@ -212,4 +211,26 @@ void HandleAuthDoneResponse(CdpSocket socket, ref EndianReader reader)

_promise?.TrySetResult();
}

sealed class ConnectionTask
{
readonly TaskCompletionSource _promise = new();

public CancellationToken CancellationToken { get; }
public ConnectionTask(CancellationToken cancellationToken)
{
CancellationToken = cancellationToken;

cancellationToken.Register(() => _promise.TrySetCanceled());
}

public void TrySetResult()
=> _promise.TrySetResult();

public void TrySetException(Exception ex)
=> _promise.TrySetException(ex);

public TaskAwaiter GetAwaiter()
=> _promise.Task.GetAwaiter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,15 @@ protected override bool TryHandleConnectInternal(CdpSocket socket, ConnectionHea
static readonly IReadOnlyList<EndpointMetadata> UpgradeEndpoints = [EndpointMetadata.Tcp];

UpgradeInstance? _currentUpgrade;
public async ValueTask<CdpSocket> RequestUpgradeAsync(CdpSocket oldSocket)
public async ValueTask<CdpSocket> UpgradeAsync(CdpSocket oldSocket)
{
if (_currentUpgrade != null)
if (Interlocked.CompareExchange(ref _currentUpgrade, new(), null) is not null)
throw new InvalidOperationException("Only a single upgrade may occur at the same time");

_currentUpgrade = new();
try
{
_logger.SendingUpgradeRequest(_currentUpgrade.Id, UpgradeEndpoints);
SendUpgradeRequest(oldSocket, _currentUpgrade.Id, UpgradeEndpoints);
return await _currentUpgrade.Promise.Task;
}
finally
{
_currentUpgrade = null;
}

void SendUpgradeRequest(CdpSocket socket, Guid upgradeId, IReadOnlyList<EndpointMetadata> endpoints)
{
CommonHeader header = new()
{
Type = MessageType.Connect
Expand All @@ -73,11 +63,17 @@ void SendUpgradeRequest(CdpSocket socket, Guid upgradeId, IReadOnlyList<Endpoint

new UpgradeRequest()
{
UpgradeId = upgradeId,
Endpoints = endpoints
UpgradeId = _currentUpgrade.Id,
Endpoints = UpgradeEndpoints
}.Write(writer);

_session.SendMessage(socket, header, writer);
_session.SendMessage(oldSocket, header, writer);

return await _currentUpgrade;
}
finally
{
_currentUpgrade = null;
}
}

Expand Down Expand Up @@ -107,19 +103,15 @@ async void FindNewEndpoint()
if (_currentUpgrade == null)
return;

_currentUpgrade.NewSocket = tasks.FirstOrDefault(x => x != null);
if (_currentUpgrade.NewSocket == null)
{
_currentUpgrade.Promise.TrySetCanceled();
if (!_currentUpgrade.TryChooseSocket(tasks.FirstOrDefault(x => x != null)))
return;
}

SendUpgradFinalization(oldSocket);

// Cancel after timeout if upgrade has not finished yet
await Task.Delay(UpgradeInstance.Timeout);

_currentUpgrade?.Promise.TrySetCanceled();
_currentUpgrade?.TrySetCanceled();
}
}

Expand Down Expand Up @@ -176,7 +168,7 @@ void HandleUpgradeFailure(ref EndianReader reader)
{
var msg = HResultPayload.Parse(ref reader);

_currentUpgrade?.Promise.TrySetException(
_currentUpgrade?.TrySetException(
new Exception($"Transport upgrade failed with HResult {msg.HResult} (hresult: {HResultPayload.HResultToString(msg.HResult)}, errorCode: {HResultPayload.ErrorCodeToString(msg.HResult)})")
);
}
Expand All @@ -195,19 +187,41 @@ void HandleTransportConfirmation(CdpSocket socket, ref EndianReader reader)
RemoteEndpoint = socket.Endpoint;

// Complete promise
_currentUpgrade.Promise.TrySetResult(socket);
_currentUpgrade.TrySetResult(socket);
}

sealed class UpgradeInstance
{
public static readonly TimeSpan Timeout = TimeSpan.FromSeconds(2);

public Guid Id { get; } = Guid.NewGuid();
public TaskCompletionSource<CdpSocket> Promise { get; } = new();

readonly TaskCompletionSource<CdpSocket> _promise = new();
public bool TrySetCanceled()
=> _promise.TrySetCanceled();

public bool TrySetResult(CdpSocket socket)
=> _promise.TrySetResult(socket);

public bool TrySetException(Exception ex)
=> _promise.TrySetException(ex);

public TaskAwaiter<CdpSocket> GetAwaiter()
=> Promise.Task.GetAwaiter();
=> _promise.Task.GetAwaiter();

CdpSocket? _newSocket;
public bool TryChooseSocket(CdpSocket? newSocket)
{
if (newSocket is null)
{
_promise.TrySetCanceled();
return false;
}

return Interlocked.CompareExchange(ref _newSocket, newSocket, null) is null;
}

public CdpSocket? NewSocket { get; set; }
public CdpSocket? NewSocket
=> _newSocket;
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
namespace ShortDev.Microsoft.ConnectedDevices.Transports.Bluetooth;
public sealed class BluetoothTransport(IBluetoothHandler handler) : ICdpTransport, ICdpDiscoverableTransport
{
public CdpTransportType TransportType { get; } = CdpTransportType.Rfcomm;
readonly IBluetoothHandler _handler = handler;

public IBluetoothHandler Handler { get; } = handler;
public CdpTransportType TransportType { get; } = CdpTransportType.Rfcomm;
public bool IsEnabled => _handler.IsEnabled;

public event DeviceConnectedEventHandler? DeviceConnected;
public async Task Listen(CancellationToken cancellationToken)
{
await Handler.ListenRfcommAsync(
await _handler.ListenRfcommAsync(
new RfcommOptions()
{
ServiceId = Constants.RfcommServiceId,
Expand All @@ -19,21 +20,21 @@ await Handler.ListenRfcommAsync(
);
}

public async Task<CdpSocket> ConnectAsync(EndpointInfo endpoint)
=> await Handler.ConnectRfcommAsync(endpoint, new RfcommOptions()
public async Task<CdpSocket> ConnectAsync(EndpointInfo endpoint, CancellationToken cancellationToken = default)
=> await _handler.ConnectRfcommAsync(endpoint, new RfcommOptions()
{
ServiceId = Constants.RfcommServiceId,
ServiceName = Constants.RfcommServiceName,
SocketConnected = (socket) => DeviceConnected?.Invoke(this, socket)
});
}, cancellationToken);

public async Task Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellationToken)
{
await Handler.AdvertiseBLeBeaconAsync(
await _handler.AdvertiseBLeBeaconAsync(
new AdvertiseOptions()
{
ManufacturerId = Constants.BLeBeaconManufacturerId,
BeaconData = new BLeBeacon(deviceInfo.Type, Handler.MacAddress, deviceInfo.Name)
BeaconData = new BLeBeacon(deviceInfo.Type, _handler.MacAddress, deviceInfo.Name)
},
cancellationToken
);
Expand All @@ -42,7 +43,7 @@ await Handler.AdvertiseBLeBeaconAsync(
public event DeviceDiscoveredEventHandler? DeviceDiscovered;
public async Task Discover(CancellationToken cancellationToken)
{
await Handler.ScanBLeAsync(new()
await _handler.ScanBLeAsync(new()
{
OnDeviceDiscovered = (advertisement, rssi) =>
{
Expand All @@ -65,5 +66,5 @@ public void Dispose()
}

public EndpointInfo GetEndpoint()
=> new(TransportType, Handler.MacAddress.ToStringFormatted(), Constants.RfcommServiceId);
=> new(TransportType, _handler.MacAddress.ToStringFormatted(), Constants.RfcommServiceId);
}
Loading
Loading