Skip to content

Commit 513689a

Browse files
committed
Improve detection of IStructId namespace from compilation
The existing approach of using analyzer config options breaks if there are transitive project references involved, since the options will contain potentially the wrong namespace (the one for the project being built rather than the one where IStructId actually exists). This changes the approach to looking up the types by their metadata name plus a codegen attribute which we assume users won't be using for their own code (even if they do happen to use `IStructId` for some other purpose).
1 parent a727008 commit 513689a

12 files changed

Lines changed: 89 additions & 62 deletions

src/StructId.Analyzer/AnalysisExtensions.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ public static CSharpParseOptions GetParseOptions(this Compilation compilation)
2929
=> (CSharpParseOptions?)compilation.SyntaxTrees.FirstOrDefault()?.Options ??
3030
CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest);
3131

32+
public static bool IsGeneratedByStructId(this ISymbol symbol)
33+
=> symbol.GetAttributes().Any(a
34+
=> a.AttributeClass?.Name == "GeneratedCodeAttribute" &&
35+
a.ConstructorArguments.Select(c => c.Value).OfType<string>().Any(v => v == nameof(StructId)));
36+
3237
/// <summary>
3338
/// Checks whether the <paramref name="this"/> type inherits or implements the
3439
/// <paramref name="baseTypeOrInterface"/> type, even if it's a generic type.
@@ -62,12 +67,6 @@ @this is INamedTypeSymbol namedActual &&
6267
return Is(@this.BaseType, baseTypeOrInterface, looseGenerics);
6368
}
6469

65-
public static string GetStructIdNamespace(this AnalyzerConfigOptions options)
66-
=> options.TryGetValue("build_property.StructIdNamespace", out var ns) && !string.IsNullOrEmpty(ns) ? ns : "StructId";
67-
68-
public static IncrementalValueProvider<string> GetStructIdNamespace(this IncrementalValueProvider<AnalyzerConfigOptionsProvider> options)
69-
=> options.Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId");
70-
7170
public static bool ImplementsExplicitly(this INamedTypeSymbol namedTypeSymbol, INamedTypeSymbol interfaceTypeSymbol)
7271
{
7372
if (interfaceTypeSymbol.IsUnboundGenericType && interfaceTypeSymbol.TypeParameters.Length == 1)
@@ -156,13 +155,19 @@ public static string ToFileName(this ITypeSymbol type)
156155

157156
public static bool IsStructId(this ITypeSymbol type) => type.AllInterfaces.Any(x => x.Name == "IStructId");
158157

158+
public static bool IsValueTemplate(this INamedTypeSymbol symbol)
159+
=> symbol.GetAttributes().Any(IsValueTemplate);
160+
159161
public static bool IsValueTemplate(this AttributeData attribute)
160162
=> attribute.AttributeClass?.Name == "TValue" ||
161163
attribute.AttributeClass?.Name == "TValueAttribute";
162164

163165
public static bool IsValueTemplate(this AttributeSyntax attribute)
164166
=> attribute.Name.ToString() == "TValue" || attribute.Name.ToString() == "TValueAttribute";
165167

168+
public static bool IsStructIdTemplate(this INamedTypeSymbol symbol)
169+
=> symbol.GetAttributes().Any(IsStructIdTemplate);
170+
166171
public static bool IsStructIdTemplate(this AttributeData attribute)
167172
=> attribute.AttributeClass?.Name == "TStructId" ||
168173
attribute.AttributeClass?.Name == "TStructIdAttribute";

src/StructId.Analyzer/BaseGenerator.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ protected record struct TemplateArgs(INamedTypeSymbol TSelf, INamedTypeSymbol TV
2626

2727
public virtual void Initialize(IncrementalGeneratorInitializationContext context)
2828
{
29-
var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace();
30-
3129
var known = context.CompilationProvider
32-
.Combine(structIdNamespace)
33-
.Select((x, _) => new KnownTypes(x.Left, x.Right));
30+
.Select((x, _) => new KnownTypes(x));
3431

3532
// Locate the required type
3633
var types = context.CompilationProvider

src/StructId.Analyzer/CodeTemplate.cs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ public static string Apply(string template, string structIdType, string valueTyp
3131

3232
public static string Apply(string template, string valueType, bool normalizeWhitespace = false)
3333
{
34-
var applied = ApplyImpl(Parse(template), valueType);
34+
var applied = ApplyValueImpl(Parse(template), valueType);
3535

3636
return normalizeWhitespace ?
3737
applied.NormalizeWhitespace().ToFullString().Trim() :
3838
applied.ToFullString().Trim();
3939
}
4040

41-
public static SyntaxNode ApplyValue(this SyntaxNode node, INamedTypeSymbol valueType) => ApplyImpl(node, valueType.ToFullName());
41+
public static SyntaxNode ApplyValue(this SyntaxNode node, INamedTypeSymbol valueType) => ApplyValueImpl(node, valueType.ToFullName());
4242

4343
public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId)
4444
{
@@ -59,7 +59,7 @@ public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId)
5959
return ApplyImpl(root, structId.Name, tid, targetNamespace, corens);
6060
}
6161

62-
static SyntaxNode ApplyImpl(this SyntaxNode node, string valueType)
62+
static SyntaxNode ApplyValueImpl(this SyntaxNode node, string valueType)
6363
{
6464
var root = node.SyntaxTree.GetCompilationUnitRoot();
6565
if (root == null)
@@ -194,7 +194,7 @@ bool IsFileLocal(TypeDeclarationSyntax node) =>
194194
!node.AttributeLists.Any(list => list.Attributes.Any(a => a.IsValueTemplate()));
195195
}
196196

197-
class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter
197+
class TemplateRewriter(string tself, string tvalue) : CSharpSyntaxRewriter
198198
{
199199
public override SyntaxNode? VisitRecordDeclaration(RecordDeclarationSyntax node)
200200
{
@@ -282,8 +282,20 @@ class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter
282282
return IdentifierName(tself)
283283
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
284284
.WithTrailingTrivia(node.Identifier.TrailingTrivia);
285-
else if (node.Identifier.Text == "TId" || node.Identifier.Text == "TValue")
286-
return IdentifierName(tid)
285+
286+
if (node.Identifier.Text.StartsWith("TSelf_"))
287+
return IdentifierName(node.Identifier.Text.Replace("TSelf_", tvalue.Replace('.', '_') + "_"))
288+
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
289+
.WithTrailingTrivia(node.Identifier.TrailingTrivia);
290+
291+
// TODO: remove TId as it's legacy
292+
if (node.Identifier.Text == "TId" || node.Identifier.Text == "TValue")
293+
return IdentifierName(tvalue)
294+
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
295+
.WithTrailingTrivia(node.Identifier.TrailingTrivia);
296+
297+
if (node.Identifier.Text.StartsWith("TValue_"))
298+
return IdentifierName(node.Identifier.Text.Replace("TValue_", tvalue.Replace('.', '_') + "_"))
287299
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
288300
.WithTrailingTrivia(node.Identifier.TrailingTrivia);
289301

@@ -297,8 +309,20 @@ public override SyntaxToken VisitToken(SyntaxToken token)
297309
return Identifier(tself)
298310
.WithLeadingTrivia(token.LeadingTrivia)
299311
.WithTrailingTrivia(token.TrailingTrivia);
300-
else if (token.IsKind(SyntaxKind.IdentifierToken) && (token.Text == "TId" || token.Text == "TValue"))
301-
return Identifier(tid)
312+
313+
if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text.StartsWith("TSelf_"))
314+
return Identifier(token.Text.Replace("TSelf_", tvalue.Replace('.', '_') + "_"))
315+
.WithLeadingTrivia(token.LeadingTrivia)
316+
.WithTrailingTrivia(token.TrailingTrivia);
317+
318+
// TODO: remove TId as it's legacy
319+
if (token.IsKind(SyntaxKind.IdentifierToken) && (token.Text == "TId" || token.Text == "TValue"))
320+
return Identifier(tvalue)
321+
.WithLeadingTrivia(token.LeadingTrivia)
322+
.WithTrailingTrivia(token.TrailingTrivia);
323+
324+
if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text.StartsWith("TValue_"))
325+
return Identifier(token.Text.Replace("TValue_", tvalue.Replace('.', '_') + "_"))
302326
.WithLeadingTrivia(token.LeadingTrivia)
303327
.WithTrailingTrivia(token.TrailingTrivia);
304328

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
1-
using Microsoft.CodeAnalysis;
1+
using System.Linq;
2+
using Microsoft.CodeAnalysis;
23

34
namespace StructId;
45

56
/// <summary>
67
/// Provides access to some common types and properties used in the compilation.
78
/// </summary>
89
/// <param name="Compilation">The compilation used to resolve the known types.</param>
9-
/// <param name="StructIdNamespace">The namespace for StructId types.</param>
10-
public record KnownTypes(Compilation Compilation, string StructIdNamespace)
10+
public record KnownTypes(Compilation Compilation)
1111
{
12+
public string StructIdNamespace => IStructId?.ContainingNamespace.ToFullName() ?? "StructId";
13+
1214
/// <summary>
1315
/// System.String
1416
/// </summary>
1517
public INamedTypeSymbol String { get; } = Compilation.GetTypeByMetadataName("System.String")!;
18+
1619
/// <summary>
1720
/// StructId.IStructId
1821
/// </summary>
19-
public INamedTypeSymbol? IStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId");
22+
public INamedTypeSymbol? IStructId { get; } = Compilation
23+
.GetAllTypes(true)
24+
.FirstOrDefault(x => x.MetadataName == "IStructId" && x.IsGeneratedByStructId());
25+
2026
/// <summary>
2127
/// StructId.IStructId{T}
2228
/// </summary>
23-
public INamedTypeSymbol? IStructIdT { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId`1");
24-
/// <summary>
25-
/// StructId.TStructIdAttribute
26-
/// </summary>
27-
public INamedTypeSymbol? TStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.TStructIdAttribute");
29+
public INamedTypeSymbol? IStructIdT { get; } = Compilation
30+
.GetAllTypes(true)
31+
.FirstOrDefault(x => x.MetadataName == "IStructId`1" && x.IsGeneratedByStructId());
2832
}

src/StructId.Analyzer/NewtonsoftJsonGenerator.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,21 @@ public override void Initialize(IncrementalGeneratorInitializationContext contex
1515
{
1616
base.Initialize(context);
1717

18+
var source = context.CompilationProvider
19+
.Select((x, _) => (new KnownTypes(x), x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1")));
20+
1821
context.RegisterSourceOutput(
19-
context.CompilationProvider
20-
.Select((x, _) => x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1"))
21-
.Combine(context.AnalyzerConfigOptionsProvider.GetStructIdNamespace()),
22+
source,
2223
(context, source) =>
2324
{
24-
if (source.Left == null)
25+
(var known, var converter) = source;
26+
if (converter == null)
2527
return;
2628

2729
context.AddSource("NewtonsoftJsonConverter.cs", SourceText.From(
2830
ThisAssembly.Resources.Templates.NewtonsoftJsonConverter_1.Text
29-
.Replace("namespace StructId;", $"namespace {source.Right};")
30-
.Replace("using StructId;", $"using {source.Right};"),
31+
.Replace("namespace StructId;", $"namespace {known.StructIdNamespace};")
32+
.Replace("using StructId;", $"using {known.StructIdNamespace};"),
3133
Encoding.UTF8));
3234
});
3335
}

src/StructId.Analyzer/RecordAnalyzer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ public override void Initialize(AnalysisContext context)
3030

3131
static void Analyze(SyntaxNodeAnalysisContext context)
3232
{
33-
var ns = context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.GetStructIdNamespace();
33+
var known = new KnownTypes(context.Compilation);
3434

3535
if (context.Node is not TypeDeclarationSyntax typeDeclaration ||
36-
context.Compilation.GetTypeByMetadataName($"{ns}.IStructId`1") is not { } structIdTypeOfT ||
37-
context.Compilation.GetTypeByMetadataName($"{ns}.IStructId") is not { } structIdType)
36+
known.IStructIdT is not { } structIdTypeOfT ||
37+
known.IStructId is not { } structIdType)
3838
return;
3939

4040
var symbol = context.SemanticModel.GetDeclaredSymbol(typeDeclaration);

src/StructId.Analyzer/TemplateAnalyzer.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ public override void Initialize(AnalysisContext context)
3030

3131
static void Analyze(SyntaxNodeAnalysisContext context)
3232
{
33-
var ns = context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.GetStructIdNamespace();
34-
3533
if (context.Node is not TypeDeclarationSyntax typeDeclaration ||
3634
!typeDeclaration.AttributeLists.Any(list => list.Attributes.Any(attr => attr.IsStructIdTemplate())))
3735
return;

src/StructId.Analyzer/TemplatedGenerator.cs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,8 @@ public bool AppliesTo(INamedTypeSymbol valueType)
8080

8181
public void Initialize(IncrementalGeneratorInitializationContext context)
8282
{
83-
var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace();
84-
8583
var known = context.CompilationProvider
86-
.Combine(structIdNamespace)
87-
.Select((x, _) => new KnownTypes(x.Left, x.Right));
84+
.Select((x, _) => new KnownTypes(x));
8885

8986
var templates = context.CompilationProvider
9087
.SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType<INamedTypeSymbol>())
@@ -99,38 +96,39 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
9996
.Combine(known)
10097
.Select((x, cancellation) =>
10198
{
102-
var (structId, known) = x;
99+
var (tself, known) = x;
103100
// We infer the idType from the required primary constructor Value parameter type
104-
var idType = (INamedTypeSymbol)structId.GetMembers().OfType<IPropertySymbol>().First(p => p.Name == "Value").Type;
105-
var attribute = structId.GetAttributes().First(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructId));
101+
var tvalue = (INamedTypeSymbol)tself.GetMembers().OfType<IPropertySymbol>().First(p => p.Name == "Value").Type;
102+
var attribute = tself.GetAttributes().First(a => a.IsStructIdTemplate());
106103

107104
// The id type isn't declared in the same file, so we don't do anything fancy with it.
108-
if (idType.DeclaringSyntaxReferences.Length == 0)
109-
return new Template(structId, idType, attribute, known);
105+
if (tvalue.DeclaringSyntaxReferences.Length == 0)
106+
return new Template(tself, tvalue, attribute, known);
110107

111108
// Otherwise, the idType is a file-local type with a single interface
112-
var type = idType.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax;
109+
var type = tvalue.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax;
113110
var iface = type?.BaseList?.Types.FirstOrDefault()?.Type;
114111
if (type == null || iface == null)
115-
return new Template(structId, idType, attribute, known) { OriginalTValue = idType };
112+
return new Template(tself, tvalue, attribute, known) { OriginalTValue = tvalue };
116113

117114
if (x.Right.Compilation.GetSemanticModel(type.SyntaxTree).GetSymbolInfo(iface).Symbol is not INamedTypeSymbol ifaceType)
118-
return new Template(structId, idType, attribute, known);
115+
return new Template(tself, tvalue, attribute, known);
119116

120117
// if the interface is a generic type with a single type argument that is the same as the idType
121118
// make it an unbound generic type. We'll bind it to the actual idType later at template render time.
122-
if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(idType, SymbolEqualityComparer.Default))
119+
if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(tvalue, SymbolEqualityComparer.Default))
123120
ifaceType = ifaceType.ConstructUnboundGenericType();
124121

125-
return new Template(structId, ifaceType, attribute, known)
122+
return new Template(tself, ifaceType, attribute, known)
126123
{
127-
OriginalTValue = idType
124+
OriginalTValue = tvalue
128125
};
129126
})
130127
.Collect();
131128

132129
var ids = context.CompilationProvider
133-
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>())
130+
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>()
131+
.Where(t => !t.IsValueTemplate() && !t.IsStructIdTemplate()))
134132
.Where(x => x.IsRecord && x.IsValueType && x.IsPartial())
135133
.Combine(known)
136134
.Where(x => x.Left.Is(x.Right.IStructId) || x.Left.Is(x.Right.IStructIdT))

src/StructId.Analyzer/TemplatizedTValueExtensions.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,8 @@ static class TemplatizedTValueExtensions
8989
/// </summary>
9090
public static IncrementalValuesProvider<TemplatizedTValue> SelectTemplatizedValues(this IncrementalGeneratorInitializationContext context)
9191
{
92-
var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace();
93-
9492
var known = context.CompilationProvider
95-
.Combine(structIdNamespace)
96-
.Select((x, _) => new KnownTypes(x.Left, x.Right));
93+
.Select((x, _) => new KnownTypes(x));
9794

9895
var templates = context.CompilationProvider
9996
.SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType<INamedTypeSymbol>())
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
<Project>
22

3-
<ItemGroup>
4-
<CompilerVisibleProperty Include="StructIdNamespace" />
5-
</ItemGroup>
6-
73
</Project>

0 commit comments

Comments
 (0)