Skip to content

Commit 4e7319e

Browse files
Make the .NET library NativeAOT compatible (github#81)
* Make the .NET library NativeAOT compatible - Enabled AOT analyzers - Disabled STJ reflection in the test project (to help vet NAOT correctness) - Use source generation for all types that may be serialized - Remove all use of anonymous types - Removed <autogenerated/> from the source generated code, as it was suppressing the analyzers - Added support for propagating StreamJsonRpc's tracing to the CopilotClient's ILogger. I used this for debugging and decided to leave it - Updated StreamJsonRpc to a newly published version on nuget to pick up NativeAOT fixes - Cleaned up some formatting in the session types generator, in particular using a file-scoped namespace and removing the top-level indentation * Add missing JsonSerializable attributes for NativeAOT * Fix JSON options in tests for AOT --------- Co-authored-by: Steve Sanderson <SteveSandersonMS@users.noreply.github.com>
1 parent f28a23e commit 4e7319e

File tree

9 files changed

+1262
-1034
lines changed

9 files changed

+1262
-1034
lines changed

dotnet/src/Client.cs

Lines changed: 113 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ namespace GitHub.Copilot.SDK;
4949
/// await session.SendAsync(new MessageOptions { Prompt = "Hello!" });
5050
/// </code>
5151
/// </example>
52-
public class CopilotClient : IDisposable, IAsyncDisposable
52+
public partial class CopilotClient : IDisposable, IAsyncDisposable
5353
{
5454
private readonly ConcurrentDictionary<string, CopilotSession> _sessions = new();
5555
private readonly CopilotClientOptions _options;
@@ -461,7 +461,7 @@ public async Task<PingResponse> PingAsync(string? message = null, CancellationTo
461461
var connection = await EnsureConnectedAsync(cancellationToken);
462462

463463
return await connection.Rpc.InvokeWithCancellationAsync<PingResponse>(
464-
"ping", [new { message }], cancellationToken);
464+
"ping", [new PingRequest { Message = message }], cancellationToken);
465465
}
466466

467467
/// <summary>
@@ -554,7 +554,7 @@ public async Task DeleteSessionAsync(string sessionId, CancellationToken cancell
554554
var connection = await EnsureConnectedAsync(cancellationToken);
555555

556556
var response = await connection.Rpc.InvokeWithCancellationAsync<DeleteSessionResponse>(
557-
"session.delete", [new { sessionId }], cancellationToken);
557+
"session.delete", [new DeleteSessionRequest(sessionId)], cancellationToken);
558558

559559
if (!response.Success)
560560
{
@@ -604,7 +604,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio
604604
{
605605
var expectedVersion = SdkProtocolVersion.GetVersion();
606606
var pingResponse = await connection.Rpc.InvokeWithCancellationAsync<PingResponse>(
607-
"ping", [new { message = (string?)null }], cancellationToken);
607+
"ping", [new PingRequest()], cancellationToken);
608608

609609
if (!pingResponse.ProtocolVersion.HasValue)
610610
{
@@ -754,23 +754,45 @@ private async Task<Connection> ConnectToServerAsync(Process? cliProcess, string?
754754
outputStream = networkStream;
755755
}
756756

757-
var rpc = new JsonRpc(new HeaderDelimitedMessageHandler(outputStream, inputStream, CreateFormatter()));
758-
rpc.AddLocalRpcTarget(new RpcHandler(this));
757+
var rpc = new JsonRpc(new HeaderDelimitedMessageHandler(
758+
outputStream,
759+
inputStream,
760+
CreateSystemTextJsonFormatter()))
761+
{
762+
TraceSource = new LoggerTraceSource(_logger),
763+
};
764+
765+
var handler = new RpcHandler(this);
766+
rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent);
767+
rpc.AddLocalRpcMethod("tool.call", handler.OnToolCall);
768+
rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequest);
759769
rpc.StartListening();
760770
return new Connection(rpc, cliProcess, tcpClient, networkStream);
761771
}
762772

763-
[UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Using the Json source generator.")]
764-
[UnconditionalSuppressMessage("AOT", "IL3050", Justification = "Using the Json source generator.")]
765-
static IJsonRpcMessageFormatter CreateFormatter()
773+
[UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")]
774+
[UnconditionalSuppressMessage("AOT", "IL3050", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")]
775+
private static SystemTextJsonFormatter CreateSystemTextJsonFormatter() =>
776+
new SystemTextJsonFormatter() { JsonSerializerOptions = SerializerOptionsForMessageFormatter };
777+
778+
private static JsonSerializerOptions SerializerOptionsForMessageFormatter { get; } = CreateSerializerOptions();
779+
780+
private static JsonSerializerOptions CreateSerializerOptions()
766781
{
767782
var options = new JsonSerializerOptions(JsonSerializerDefaults.Web)
768783
{
769784
AllowOutOfOrderMetadataProperties = true,
770785
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
771786
};
772787

773-
return new SystemTextJsonFormatter() { JsonSerializerOptions = options };
788+
options.TypeInfoResolverChain.Add(ClientJsonContext.Default);
789+
options.TypeInfoResolverChain.Add(TypesJsonContext.Default);
790+
options.TypeInfoResolverChain.Add(CopilotSession.SessionJsonContext.Default);
791+
options.TypeInfoResolverChain.Add(SessionEventsJsonContext.Default);
792+
793+
options.MakeReadOnly();
794+
795+
return options;
774796
}
775797

776798
internal CopilotSession? GetSession(string sessionId) =>
@@ -803,9 +825,7 @@ public async ValueTask DisposeAsync()
803825

804826
private class RpcHandler(CopilotClient client)
805827
{
806-
[JsonRpcMethod("session.event")]
807-
public void OnSessionEvent(string sessionId,
808-
JsonElement? @event)
828+
public void OnSessionEvent(string sessionId, JsonElement? @event)
809829
{
810830
var session = client.GetSession(sessionId);
811831
if (session != null && @event != null)
@@ -818,7 +838,6 @@ public void OnSessionEvent(string sessionId,
818838
}
819839
}
820840

821-
[JsonRpcMethod("tool.call")]
822841
public async Task<ToolCallResponse> OnToolCall(string sessionId,
823842
string toolCallId,
824843
string toolName,
@@ -891,7 +910,7 @@ public async Task<ToolCallResponse> OnToolCall(string sessionId,
891910
// something we don't control? an error?)
892911
TextResultForLlm = result is JsonElement { ValueKind: JsonValueKind.String } je
893912
? je.GetString()!
894-
: JsonSerializer.Serialize(result, tool.JsonSerializerOptions),
913+
: JsonSerializer.Serialize(result, tool.JsonSerializerOptions.GetTypeInfo(typeof(object))),
895914
};
896915
return new ToolCallResponse(toolResultObject);
897916
}
@@ -908,7 +927,6 @@ public async Task<ToolCallResponse> OnToolCall(string sessionId,
908927
}
909928
}
910929

911-
[JsonRpcMethod("permission.request")]
912930
public async Task<PermissionRequestResponse> OnPermissionRequest(string sessionId, JsonElement permissionRequest)
913931
{
914932
var session = client.GetSession(sessionId);
@@ -959,7 +977,7 @@ public static string Escape(string arg)
959977
}
960978

961979
// Request/Response types for RPC
962-
private record CreateSessionRequest(
980+
internal record CreateSessionRequest(
963981
string? Model,
964982
string? SessionId,
965983
List<ToolDefinition>? Tools,
@@ -975,7 +993,7 @@ private record CreateSessionRequest(
975993
List<string>? SkillDirectories,
976994
List<string>? DisabledSkills);
977995

978-
private record ToolDefinition(
996+
internal record ToolDefinition(
979997
string Name,
980998
string? Description,
981999
JsonElement Parameters /* JSON schema */)
@@ -984,10 +1002,10 @@ public static ToolDefinition FromAIFunction(AIFunction function)
9841002
=> new ToolDefinition(function.Name, function.Description, function.JsonSchema);
9851003
}
9861004

987-
private record CreateSessionResponse(
1005+
internal record CreateSessionResponse(
9881006
string SessionId);
9891007

990-
private record ResumeSessionRequest(
1008+
internal record ResumeSessionRequest(
9911009
string SessionId,
9921010
List<ToolDefinition>? Tools,
9931011
ProviderConfig? Provider,
@@ -998,24 +1016,93 @@ private record ResumeSessionRequest(
9981016
List<string>? SkillDirectories,
9991017
List<string>? DisabledSkills);
10001018

1001-
private record ResumeSessionResponse(
1019+
internal record ResumeSessionResponse(
10021020
string SessionId);
10031021

1004-
private record GetLastSessionIdResponse(
1022+
internal record GetLastSessionIdResponse(
10051023
string? SessionId);
10061024

1007-
private record DeleteSessionResponse(
1025+
internal record DeleteSessionRequest(
1026+
string SessionId);
1027+
1028+
internal record DeleteSessionResponse(
10081029
bool Success,
10091030
string? Error);
10101031

1011-
private record ListSessionsResponse(
1032+
internal record ListSessionsResponse(
10121033
List<SessionMetadata> Sessions);
10131034

1014-
private record ToolCallResponse(
1035+
internal record ToolCallResponse(
10151036
ToolResultObject? Result);
10161037

1017-
private record PermissionRequestResponse(
1038+
internal record PermissionRequestResponse(
10181039
PermissionRequestResult Result);
1040+
1041+
/// <summary>Trace source that forwards all logs to the ILogger.</summary>
1042+
internal sealed class LoggerTraceSource : TraceSource
1043+
{
1044+
public LoggerTraceSource(ILogger logger) : base(nameof(LoggerTraceSource), SourceLevels.All)
1045+
{
1046+
Listeners.Clear();
1047+
Listeners.Add(new LoggerTraceListener(logger));
1048+
}
1049+
1050+
private sealed class LoggerTraceListener(ILogger logger) : TraceListener
1051+
{
1052+
public override void TraceEvent(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, string? message) =>
1053+
logger.Log(MapLevel(eventType), "[{Source}] {Message}", source, message);
1054+
1055+
public override void TraceEvent(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, string? format, params object?[]? args) =>
1056+
logger.Log(MapLevel(eventType), "[{Source}] {Message}", source, args is null || args.Length == 0 ? format : string.Format(format ?? "", args));
1057+
1058+
public override void TraceData(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, object? data) =>
1059+
logger.Log(MapLevel(eventType), "[{Source}] {Data}", source, data);
1060+
1061+
public override void TraceData(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, params object?[]? data) =>
1062+
logger.Log(MapLevel(eventType), "[{Source}] {Data}", source, data is null ? null : string.Join(", ", data));
1063+
1064+
public override void Write(string? message) =>
1065+
logger.LogTrace("{Message}", message);
1066+
1067+
public override void WriteLine(string? message) =>
1068+
logger.LogTrace("{Message}", message);
1069+
1070+
private static LogLevel MapLevel(TraceEventType eventType) => eventType switch
1071+
{
1072+
TraceEventType.Critical => LogLevel.Critical,
1073+
TraceEventType.Error => LogLevel.Error,
1074+
TraceEventType.Warning => LogLevel.Warning,
1075+
TraceEventType.Information => LogLevel.Information,
1076+
TraceEventType.Verbose => LogLevel.Debug,
1077+
_ => LogLevel.Trace
1078+
};
1079+
}
1080+
}
1081+
1082+
[JsonSourceGenerationOptions(
1083+
JsonSerializerDefaults.Web,
1084+
AllowOutOfOrderMetadataProperties = true,
1085+
NumberHandling = JsonNumberHandling.AllowReadingFromString,
1086+
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)]
1087+
[JsonSerializable(typeof(CreateSessionRequest))]
1088+
[JsonSerializable(typeof(CreateSessionResponse))]
1089+
[JsonSerializable(typeof(CustomAgentConfig))]
1090+
[JsonSerializable(typeof(DeleteSessionRequest))]
1091+
[JsonSerializable(typeof(DeleteSessionResponse))]
1092+
[JsonSerializable(typeof(GetLastSessionIdResponse))]
1093+
[JsonSerializable(typeof(ListSessionsResponse))]
1094+
[JsonSerializable(typeof(PermissionRequestResponse))]
1095+
[JsonSerializable(typeof(PermissionRequestResult))]
1096+
[JsonSerializable(typeof(ProviderConfig))]
1097+
[JsonSerializable(typeof(ResumeSessionRequest))]
1098+
[JsonSerializable(typeof(ResumeSessionResponse))]
1099+
[JsonSerializable(typeof(SessionMetadata))]
1100+
[JsonSerializable(typeof(SystemMessageConfig))]
1101+
[JsonSerializable(typeof(ToolCallResponse))]
1102+
[JsonSerializable(typeof(ToolDefinition))]
1103+
[JsonSerializable(typeof(ToolResultAIContent))]
1104+
[JsonSerializable(typeof(ToolResultObject))]
1105+
internal partial class ClientJsonContext : JsonSerializerContext;
10191106
}
10201107

10211108
// Must inherit from AIContent as a signal to MEAI to avoid JSON-serializing the

0 commit comments

Comments
 (0)