Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions crates/bindings-csharp/BSATN.Codegen/Type.cs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ public MemberDeclaration(ISymbol member, ITypeSymbol type, DiagReporter diag)
public MemberDeclaration(IFieldSymbol field, DiagReporter diag)
: this(field, field.Type, diag) { }

public string Identifier => EscapeIdentifier(Name);

public static string GenerateBsatnFields(
Accessibility visibility,
IEnumerable<MemberDeclaration> members
Expand All @@ -431,7 +433,7 @@ IEnumerable<MemberDeclaration> members
return string.Join(
"\n ",
members.Select(m =>
$"{visStr} static readonly {m.Type.ToBSATNString()} {m.Name}{TypeUse.BsatnFieldSuffix} = new();"
$"{visStr} static readonly {m.Type.ToBSATNString()} {m.Identifier}{TypeUse.BsatnFieldSuffix} = new();"
)
);
}
Expand All @@ -442,7 +444,7 @@ public static string GenerateDefs(IEnumerable<MemberDeclaration> members) =>
// we can't use nameof(m.Type.BsatnFieldName) because the bsatn field name differs from the logical name
// assigned in the type.
members.Select(m =>
$"new(\"{m.Name}\", {m.Name}{TypeUse.BsatnFieldSuffix}.GetAlgebraicType(registrar))"
$"new(\"{m.Name}\", {m.Identifier}{TypeUse.BsatnFieldSuffix}.GetAlgebraicType(registrar))"
)
);
}
Expand All @@ -462,6 +464,12 @@ public abstract record BaseTypeDeclaration<M>
public readonly TypeKind Kind;
public readonly EquatableArray<M> Members;

/// <summary>
/// Returns the escaped version of ShortName for use in generated C# code where the type name
/// appears as an identifier (e.g., in IEquatable&lt;T&gt; or as a base type reference).
/// </summary>
public string ShortNameIdentifier => EscapeIdentifier(ShortName);

protected abstract M ConvertMember(int index, IFieldSymbol field, DiagReporter diag);

public BaseTypeDeclaration(GeneratorAttributeSyntaxContext context, DiagReporter diag)
Expand Down Expand Up @@ -557,7 +565,7 @@ public Scope.Extensions ToExtensions()

var bsatnDecls = Members.Cast<MemberDeclaration>();

extensions.BaseTypes.Add($"System.IEquatable<{ShortName}>");
extensions.BaseTypes.Add($"System.IEquatable<{ShortNameIdentifier}>");

if (Kind is TypeKind.Sum)
{
Expand All @@ -569,10 +577,10 @@ public Scope.Extensions ToExtensions()
// To avoid this, we append an underscore to the field name.
// In most cases the field name shouldn't matter anyway as you'll idiomatically use pattern matching to extract the value.
$$"""
public sealed record {{m.Name}}({{m.Type.Name}} {{m.Name}}_) : {{ShortName}}
public sealed record {{m.Identifier}}({{m.Type.Name}} {{m.Identifier}}_) : {{ShortNameIdentifier}}
{
public override string ToString() =>
$"{{m.Name}}({ SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Name}}_) })";
$"{{m.Name}}({ SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Identifier}}_) })";
}

"""
Expand All @@ -585,7 +593,7 @@ public override string ToString() =>
{{string.Join(
"\n ",
bsatnDecls.Select((m, i) =>
$"{i} => new {m.Name}({m.Name}{TypeUse.BsatnFieldSuffix}.Read(reader)),"
$"{i} => new {m.Identifier}({m.Identifier}{TypeUse.BsatnFieldSuffix}.Read(reader)),"
)
)}}
_ => throw new System.InvalidOperationException("Invalid tag value, this state should be unreachable.")
Expand All @@ -597,9 +605,9 @@ public override string ToString() =>
{{string.Join(
"\n",
bsatnDecls.Select((m, i) => $"""
case {m.Name}(var inner):
case {m.Identifier}(var inner):
writer.Write((byte){i});
{m.Name}{TypeUse.BsatnFieldSuffix}.Write(writer, inner);
{m.Identifier}{TypeUse.BsatnFieldSuffix}.Write(writer, inner);
break;
"""))}}
}
Expand All @@ -615,7 +623,7 @@ public override string ToString() =>
var hashName = $"___hash{member.Name}";

return $"""
case {member.Name}(var inner):
case {member.Identifier}(var inner):
{member.Type.GetHashCodeStatement("inner", hashName)}
return {hashName};
""";
Expand All @@ -634,14 +642,14 @@ public override string ToString() =>
public void ReadFields(System.IO.BinaryReader reader) {
{{string.Join(
"\n",
bsatnDecls.Select(m => $" {m.Name} = BSATN.{m.Name}{TypeUse.BsatnFieldSuffix}.Read(reader);")
bsatnDecls.Select(m => $" {m.Identifier} = BSATN.{m.Identifier}{TypeUse.BsatnFieldSuffix}.Read(reader);")
)}}
}

public void WriteFields(System.IO.BinaryWriter writer) {
{{string.Join(
"\n",
bsatnDecls.Select(m => $" BSATN.{m.Name}{TypeUse.BsatnFieldSuffix}.Write(writer, {m.Name});")
bsatnDecls.Select(m => $" BSATN.{m.Identifier}{TypeUse.BsatnFieldSuffix}.Write(writer, {m.Identifier});")
)}}
}

Expand All @@ -661,7 +669,7 @@ object SpacetimeDB.BSATN.IStructuralReadWrite.GetSerializer() {
public override string ToString() =>
$"{{ShortName}} {{start}} {{string.Join(
", ",
bsatnDecls.Select(m => $$"""{{m.Name}} = {SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Name}})}""")
bsatnDecls.Select(m => $$"""{{m.Name}} = {SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Identifier}})}""")
)}} {{end}}";
"""
);
Expand All @@ -680,7 +688,7 @@ public override string ToString() =>
var declHashName = (MemberDeclaration decl) => $"___hash{decl.Name}";

getHashCode = $$"""
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.GetHashCodeStatement(decl.Name, declHashName(decl))))}}
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.GetHashCodeStatement(decl.Identifier, declHashName(decl))))}}
return {{JoinOrValue(
" ^\n ",
bsatnDecls.Select(declHashName),
Expand Down Expand Up @@ -735,7 +743,7 @@ public override int GetHashCode()
public bool Equals({{fullNameMaybeRef}} that)
{
{{(Scope.IsStruct ? "" : "if (((object?)that) == null) { return false; }\n ")}}
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.EqualsStatement($"this.{decl.Name}", $"that.{decl.Name}", declEqualsName(decl))))}}
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.EqualsStatement($"this.{decl.Identifier}", $"that.{decl.Identifier}", declEqualsName(decl))))}}
return {{JoinOrValue(
" &&\n ",
bsatnDecls.Select(declEqualsName),
Expand Down
13 changes: 13 additions & 0 deletions crates/bindings-csharp/BSATN.Codegen/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,26 @@ public readonly record struct EquatableArray<T>(ImmutableArray<T> Array) : IEnum
.AddMemberOptions(SymbolDisplayMemberOptions.IncludeContainingType)
.AddMiscellaneousOptions(
SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
| SymbolDisplayMiscellaneousOptions.EscapeKeywordIdentifiers
);

public static string SymbolToName(ISymbol symbol)
{
return symbol.ToDisplayString(SymbolFormat);
}

public static string EscapeIdentifier(string name)
{
if (name.Length > 0 && name[0] == '@')
{
return name;
}

var kind = SyntaxFacts.GetKeywordKind(name);
var contextualKind = SyntaxFacts.GetContextualKeywordKind(name);
return kind != SyntaxKind.None || contextualKind != SyntaxKind.None ? $"@{name}" : name;
}

public static void RegisterSourceOutputs(
this IncrementalValuesProvider<Scope.Extensions> methods,
IncrementalGeneratorInitializationContext context
Expand Down
81 changes: 81 additions & 0 deletions crates/bindings-csharp/Codegen.Tests/Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,87 @@ public static async Task SettingsAndExplicitNames()
AssertGeneratedCodeDoesNotUseInternalBound(compilationAfterGen);
}

[Fact]
public static async Task CSharpKeywordIdentifiersAreEscapedInGeneratedCode()
{
var fixture = await Fixture.Compile("server");

const string source = """
using SpacetimeDB;

[SpacetimeDB.Table]
public partial struct KeywordTable
{
[SpacetimeDB.PrimaryKey]
public ulong @class;

public int @params;
}

[SpacetimeDB.Table(Accessor = "event")]
public partial struct AccessorKeywordTable
{
[SpacetimeDB.PrimaryKey]
[SpacetimeDB.Index.BTree(Accessor = "params")]
public int Id;
}

[SpacetimeDB.Table]
public partial struct @class
{
[SpacetimeDB.PrimaryKey]
public int Id;
}

public static partial class KeywordApis
{
[SpacetimeDB.Reducer]
public static void KeywordReducer(ReducerContext ctx, int @params, string @class)
{
_ = @params;
_ = @class;
}

[SpacetimeDB.Reducer]
public static void @class(ReducerContext ctx)
{
}

[SpacetimeDB.Procedure]
public static int KeywordProcedure(ProcedureContext ctx, int @params, int @class)
{
return @params + @class;
}

[SpacetimeDB.Procedure]
public static void @params(ProcedureContext ctx)
{
}
}
""";

var parseOptions = new CSharpParseOptions(fixture.SampleCompilation.LanguageVersion);
var tree = CSharpSyntaxTree.ParseText(source, parseOptions, path: "KeywordNames.cs");
var compilation = fixture.SampleCompilation.AddSyntaxTrees(tree);

var driver = CSharpGeneratorDriver.Create(
[
new SpacetimeDB.Codegen.Type().AsSourceGenerator(),
new SpacetimeDB.Codegen.Module().AsSourceGenerator(),
],
driverOptions: new(
disabledOutputs: IncrementalGeneratorOutputKind.None,
trackIncrementalGeneratorSteps: true
),
parseOptions: parseOptions
);

var runResult = driver.RunGenerators(compilation).GetRunResult();
var compilationAfterGen = compilation.AddSyntaxTrees(runResult.GeneratedTrees);

Assert.Empty(GetCompilationErrors(compilationAfterGen));
}

[Fact]
public static async Task TestDiagnostics()
{
Expand Down
Loading
Loading