1 module d.llvm.type;
2 
3 import d.llvm.codegen;
4 
5 import d.ir.symbol;
6 import d.ir.type;
7 
8 import source.exception;
9 
10 import util.visitor;
11 
12 import llvm.c.core;
13 
14 // Conflict with Interface in object.di
15 alias Interface = d.ir.symbol.Interface;
16 
// State used by TypeGen; TypeGen reaches it through pass.typeGenData.
struct TypeGenData {
	private {
		// Class symbol obtained from pass.object.getClassInfo(), cached on first use.
		Class classInfoClass;
		
		// LLVM type produced for each aggregate (struct, union, class, interface).
		LLVMTypeRef[Aggregate] aggTypes;
		// Type info constant registered for an aggregate.
		LLVMValueRef[Aggregate] typeInfos;
		
		// Vtable constant registered for each generated class.
		LLVMValueRef[Class] vtbls;
		// Named struct type used as the context of each function.
		LLVMTypeRef[Function] funCtxTypes;
	}
}
27 
28 struct TypeGen {
29 	private CodeGen pass;
30 	alias pass this;
31 	
	// Bind this generator to a CodeGen pass. The pass is also exposed
	// through `alias pass this`, so its members can be used unqualified.
	this(CodeGen pass) {
		this.pass = pass;
	}
35 	
36 	// XXX: lack of multiple alias this, so we do it automanually.
37 	private {
38 		@property
39 		ref Class classInfoClass() {
40 			return pass.typeGenData.classInfoClass;
41 		}
42 		
43 		@property
44 		ref LLVMTypeRef[Aggregate] typeSymbols() {
45 			return pass.typeGenData.aggTypes;
46 		}
47 		
48 		@property
49 		ref LLVMValueRef[Aggregate] typeInfos() {
50 			return pass.typeGenData.typeInfos;
51 		}
52 		
53 		@property
54 		ref LLVMValueRef[Class] vtbls() {
55 			return pass.typeGenData.vtbls;
56 		}
57 		
58 		@property
59 		ref LLVMTypeRef[Function] funCtxTypes() {
60 			return pass.typeGenData.funCtxTypes;
61 		}
62 	}
63 	
64 	LLVMValueRef getTypeInfo(Aggregate a) {
65 		if (a !in typeInfos) {
66 			this.dispatch(a);
67 		}
68 		
69 		return typeInfos[a];
70 	}
71 	
72 	// XXX: Remove ?
73 	LLVMValueRef getVtbl(Class c) {
74 		return vtbls[c];
75 	}
76 	
77 	LLVMTypeRef visit(Type t) {
78 		return t.getCanonical().accept(this);
79 	}
80 	
81 	LLVMTypeRef buildOpaque(Type t) {
82 		t = t.getCanonical();
83 		switch (t.kind) with(TypeKind) {
84 			case Struct:
85 				return buildOpaque(t.dstruct);
86 			
87 			case Union:
88 				return buildOpaque(t.dunion);
89 			
90 			case Context:
91 				return buildOpaque(t.context);
92 			
93 			default:
94 				return t.accept(this);
95 		}
96 	}
97 	
98 	LLVMTypeRef visit(BuiltinType t) {
99 		final switch(t) with(BuiltinType) {
100 			case None :
101 				assert(0, "Not Implemented");
102 			
103 			case Void :
104 				return LLVMVoidTypeInContext(llvmCtx);
105 			
106 			case Bool :
107 				return LLVMInt1TypeInContext(llvmCtx);
108 			
109 			case Char, Ubyte, Byte :
110 				return LLVMInt8TypeInContext(llvmCtx);
111 			
112 			case Wchar, Ushort, Short :
113 				return LLVMInt16TypeInContext(llvmCtx);
114 			
115 			case Dchar, Uint, Int :
116 				return LLVMInt32TypeInContext(llvmCtx);
117 			
118 			case Ulong, Long :
119 				return LLVMInt64TypeInContext(llvmCtx);
120 			
121 			case Ucent, Cent :
122 				return LLVMInt128TypeInContext(llvmCtx);
123 			
124 			case Float :
125 				return LLVMFloatTypeInContext(llvmCtx);
126 			
127 			case Double :
128 				return LLVMDoubleTypeInContext(llvmCtx);
129 			
130 			case Real :
131 				return LLVMX86FP80TypeInContext(llvmCtx);
132 			
133 			case Null :
134 				return LLVMPointerType(LLVMInt8TypeInContext(llvmCtx), 0);
135 		}
136 	}
137 	
138 	LLVMTypeRef visitPointerOf(Type t) {
139 		auto pointed = (t.kind != TypeKind.Builtin || t.builtin != BuiltinType.Void)
140 			? buildOpaque(t)
141 			: LLVMInt8TypeInContext(llvmCtx);
142 		
143 		return LLVMPointerType(pointed, 0);
144 	}
145 	
146 	LLVMTypeRef visitSliceOf(Type t) {
147 		LLVMTypeRef[2] types;
148 		types[0] = LLVMInt64TypeInContext(llvmCtx);
149 		types[1] = visitPointerOf(t);
150 		
151 		return LLVMStructTypeInContext(llvmCtx, types.ptr, 2, false);
152 	}
153 	
154 	LLVMTypeRef visitArrayOf(uint size, Type t) {
155 		return LLVMArrayType(visit(t), size);
156 	}
157 	
158 	auto buildOpaque(Struct s) {
159 		if (auto st = s in typeSymbols) {
160 			return *st;
161 		}
162 		
163 		return typeSymbols[s] = LLVMStructCreateNamed(
164 			llvmCtx,
165 			s.mangle.toStringz(context),
166 		);
167 	}
168 	
169 	LLVMTypeRef visit(Struct s) in {
170 		assert(s.step >= Step.Signed);
171 	} do {
172 		// FIXME: Ensure we don't have forward references.
173 		auto llvmStruct = buildOpaque(s);
174 		if (!LLVMIsOpaqueStruct(llvmStruct)) {
175 			return llvmStruct;
176 		}
177 		
178 		LLVMTypeRef[] types;
179 		foreach(member; s.members) {
180 			if (auto f = cast(Field) member) {
181 				types ~= visit(f.type);
182 			}
183 		}
184 		
185 		LLVMStructSetBody(llvmStruct, types.ptr, cast(uint) types.length, false);
186 		return llvmStruct;
187 	}
188 	
189 	auto buildOpaque(Union u) {
190 		if (auto ut = u in typeSymbols) {
191 			return *ut;
192 		}
193 		
194 		return typeSymbols[u] = LLVMStructCreateNamed(
195 			llvmCtx,
196 			u.mangle.toStringz(context),
197 		);
198 	}
199 	
200 	LLVMTypeRef visit(Union u) in {
201 		assert(u.step >= Step.Signed);
202 	} do {
203 		// FIXME: Ensure we don't have forward references.
204 		auto llvmStruct = buildOpaque(u);
205 		if (!LLVMIsOpaqueStruct(llvmStruct)) {
206 			return llvmStruct;
207 		}
208 		
209 		auto hasContext = u.hasContext;
210 		auto members = u.members;
211 		assert(!hasContext, "Voldemort union not supported atm");
212 		
213 		LLVMTypeRef[3] types;
214 		uint elementCount = 1 + hasContext;
215 		
216 		uint firstindex, size, dalign;
217 		foreach(i, m; members) {
218 			if (auto f = cast(Field) m) {
219 				types[hasContext] = visit(f.type);
220 				
221 				import llvm.c.target;
222 				size = cast(uint) LLVMStoreSizeOfType(targetData, types[hasContext]);
223 				dalign = cast(uint) LLVMABIAlignmentOfType(targetData, types[hasContext]);
224 				
225 				firstindex = cast(uint) (i + 1);
226 				break;
227 			}
228 		}
229 		
230 		uint extra;
231 		foreach(m; members[firstindex .. $]) {
232 			if (auto f = cast(Field) m) {
233 				auto t = visit(f.type);
234 				
235 				import llvm.c.target;
236 				auto s = cast(uint) LLVMStoreSizeOfType(targetData, t);
237 				auto a = cast(uint) LLVMABIAlignmentOfType(targetData, t);
238 				
239 				extra = ((size + extra) < s) ? s - size : extra;
240 				dalign = (a > dalign) ? a : dalign;
241 			}
242 		}
243 		
244 		if (extra > 0) {
245 			elementCount++;
246 			types[1] = LLVMArrayType(LLVMInt8TypeInContext(llvmCtx), extra);
247 		}
248 		
249 		LLVMStructSetBody(llvmStruct, types.ptr, elementCount, false);
250 		
251 		import llvm.c.target;
252 		assert(
253 			LLVMABIAlignmentOfType(targetData, llvmStruct) == dalign,
254 			"union with differing alignement are not supported."
255 		);
256 		
257 		return llvmStruct;
258 	}
259 	
260 	LLVMTypeRef visit(Class c) {
261 		// Ensure classInfo is built first.
262 		if (!classInfoClass) {
263 			classInfoClass = pass.object.getClassInfo();
264 			
265 			if (c !is classInfoClass) {
266 				visit(classInfoClass);
267 			}
268 		}
269 		
270 		if (auto ct = c in typeSymbols) {
271 			return *ct;
272 		}
273 		
274 		auto mangle = c.mangle.toString(context);
275 		auto llvmStruct = LLVMStructCreateNamed(llvmCtx, mangle.ptr);
276 		auto structPtr = typeSymbols[c] = LLVMPointerType(llvmStruct, 0);
277 		
278 		import std.string;
279 		auto classInfoPtr = visit(classInfoClass);
280 		auto classInfoStruct = LLVMGetElementType(classInfoPtr);
281 		auto vtblStruct = LLVMStructCreateNamed(llvmCtx, toStringz(mangle ~ "__vtbl"));
282 		auto vtblPtr = LLVMPointerType(vtblStruct, 0);
283 		
284 		LLVMTypeRef[2] classDataElts = [classInfoStruct, vtblStruct];
285 		auto classDataStruct = LLVMStructTypeInContext(
286 			llvmCtx,
287 			classDataElts.ptr,
288 			cast(uint) classDataElts.length,
289 			false,
290 		);
291 		
292 		import std.string;
293 		auto classData = LLVMAddGlobal(
294 			dmodule,
295 			classDataStruct,
296 			toStringz(mangle ~ "__Metadata"),
297 		);
298 		
299 		typeInfos[c] = LLVMConstBitCast(classData, classInfoPtr);
300 		
301 		LLVMValueRef[] methods;
302 		LLVMTypeRef[] initTypes = [vtblPtr];
303 		foreach(member; c.members) {
304 			if (auto m = cast(Method) member) {
305 				auto oldBody = m.fbody;
306 				scope(exit) m.fbody = oldBody;
307 				// FIXME: Do whatever is needed here.
308 				// m.fbody = null;
309 				
310 				import d.llvm.global;
311 				methods ~= GlobalGen(pass).declare(m);
312 			} else if (auto f = cast(Field) member) {
313 				if (f.index > 0) {
314 					initTypes ~= visit(f.value.type);
315 				}
316 			}
317 		}
318 		
319 		LLVMStructSetBody(
320 			llvmStruct,
321 			initTypes.ptr,
322 			cast(uint) initTypes.length,
323 			false,
324 		);
325 		
326 		import std.algorithm, std.array;
327 		auto vtblTypes = methods.map!(m => LLVMTypeOf(m)).array();
328 		LLVMStructSetBody(
329 			vtblStruct,
330 			vtblTypes.ptr,
331 			cast(uint) vtblTypes.length,
332 			false,
333 		);
334 		
335 		auto vtbl = LLVMConstNamedStruct(
336 			vtblStruct,
337 			methods.ptr,
338 			cast(uint) methods.length,
339 		);
340 		
341 		auto i32 = LLVMInt32TypeInContext(llvmCtx);
342 		LLVMValueRef[2] indices = [
343 			LLVMConstInt(i32, 0, false),
344 			LLVMConstInt(i32, 1, false),
345 		];
346 		
347 		vtbls[c] = LLVMConstInBoundsGEP(
348 			classData,
349 			indices.ptr,
350 			indices.length,
351 		);
352 		
353 		// Doing it at the end to avoid infinite recursion
354 		// when generating object.ClassInfo
355 		auto base = c.base;
356 		visit(base);
357 		
358 		LLVMValueRef[2] classInfoData = [getVtbl(classInfoClass), getTypeInfo(base)];
359 		auto classInfoGen = LLVMConstNamedStruct(
360 			classInfoStruct,
361 			classInfoData.ptr,
362 			classInfoData.length,
363 		);
364 		
365 		LLVMValueRef[2] classDataData = [classInfoGen, vtbl];
366 		auto classDataGen = LLVMConstNamedStruct(
367 			classDataStruct,
368 			classDataData.ptr,
369 			classDataData.length,
370 		);
371 		
372 		LLVMSetInitializer(classData, classDataGen);
373 		LLVMSetGlobalConstant(classData, true);
374 		LLVMSetLinkage(classData, LLVMLinkage.LinkOnceODR);
375 		
376 		return structPtr;
377 	}
378 	
379 	LLVMTypeRef visit(Enum e) {
380 		return visit(e.type);
381 	}
382 	
	// Aliases must never reach the generator: visit(Type) and
	// buildOpaque(Type) canonicalize before dispatching, so getting here
	// means a caller skipped getCanonical. Always asserts.
	LLVMTypeRef visit(TypeAlias a) {
		assert(0, "Use getCanonical");
	}
386 	
387 	LLVMTypeRef visit(Interface i) {
388 		if (auto it = i in typeSymbols) {
389 			return *it;
390 		}
391 		
392 		auto mangle = i.mangle.toString(context);
393 		auto llvmStruct = typeSymbols[i] = LLVMStructCreateNamed(llvmCtx, mangle.ptr);
394 		
395 		import std.string;
396 		auto vtblStruct = LLVMStructCreateNamed(llvmCtx, toStringz(mangle ~ "__vtbl"));
397 		LLVMTypeRef[2] elements;
398 		elements[0] = visit(pass.object.getObject());
399 		elements[1] = LLVMPointerType(vtblStruct, 0);
400 		LLVMStructSetBody(llvmStruct, elements.ptr, elements.length, false);
401 		return llvmStruct;
402 	}
403 	
404 	auto buildOpaque(Function f) {
405 		if (auto fctx = f in funCtxTypes) {
406 			return *fctx;
407 		}
408 		
409 		import std.string;
410 		return funCtxTypes[f] = LLVMStructCreateNamed(
411 			llvmCtx,
412 			toStringz("S" ~ f.name.toString(context) ~ ".ctx"),
413 		);
414 	}
415 	
416 	LLVMTypeRef visit(Function f) in {
417 		assert(
418 			f.step >= Step.Processed,
419 			f.name.toString(pass.context) ~ " isn't signed",
420 		);
421 	} do {
422 		auto ctxStruct = buildOpaque(f);
423 		if (!LLVMIsOpaqueStruct(ctxStruct)) {
424 			return ctxStruct;
425 		}
426 		
427 		auto count = cast(uint) f.closure.length + f.hasContext;
428 		
429 		LLVMTypeRef[] ctxElts;
430 		ctxElts.length = count;
431 		
432 		if (f.hasContext) {
433 			auto parentCtxType = f.type.parameters[0].getType();
434 			ctxElts[0] = LLVMPointerType(visit(parentCtxType), 0);
435 		}
436 		
437 		foreach(v, i; f.closure) {
438 			ctxElts[i] = visit(v.type);
439 		}
440 		
441 		LLVMStructSetBody(ctxStruct, ctxElts.ptr, count, false);
442 		return ctxStruct;
443 	}
444 	
445 	private auto buildParamType(ParamType pt) {
446 		auto t = visit(pt.getType());
447 		if (pt.isRef) {
448 			t = LLVMPointerType(t, 0);
449 		}
450 		
451 		return t;
452 	}
453 	
454 	LLVMTypeRef visit(FunctionType f) {
455 		import std.algorithm, std.array;
456 		auto params = f.getFunction().parameters.map!(p => buildParamType(p)).array();
457 		auto fun = LLVMPointerType(LLVMFunctionType(
458 			buildParamType(f.returnType),
459 			params.ptr,
460 			cast(uint) params.length,
461 			f.isVariadic,
462 		), 0);
463 		
464 		auto contexts = f.contexts;
465 		if (contexts.length == 0) {
466 			return fun;
467 		}
468 		
469 		auto length = cast(uint) contexts.length;
470 		
471 		LLVMTypeRef[] types;
472 		types.length = length + 1;
473 		
474 		foreach(i, _; contexts) {
475 			types[i] = params[i];
476 		}
477 		
478 		types[length] = fun;
479 		return LLVMStructTypeInContext(llvmCtx, types.ptr, length + 1, false);
480 	}
481 	
482 	LLVMTypeRef visit(Type[] seq) {
483 		import std.algorithm, std.array;
484 		auto types = seq.map!(t => visit(t)).array();
485 		return LLVMStructTypeInContext(llvmCtx, types.ptr, cast(uint) types.length, false);
486 	}
487 	
	// Patterns have no LLVM representation; reaching this overload is a
	// bug in the caller. Always asserts.
	LLVMTypeRef visit(Pattern p) {
		assert(0, "Patterns cannot be generated.");
	}
491 	
492 	import d.ir.error;
	// Error types have no LLVM representation; reaching this overload is
	// a bug in the caller. Always asserts.
	LLVMTypeRef visit(CompileError e) {
		assert(0, "Error type can't be generated.");
	}
496 }