1 module d.llvm.local;
2 
3 import d.llvm.codegen;
4 
5 import d.ir.dscope;
6 import d.ir.symbol;
7 import d.ir.type;
8 
9 import llvm.c.core;
10 
11 alias LocalPass = LocalGen*;
12 
// Controls when LocalGen.declare(Function) also generates function bodies.
// The mode is also forwarded to GlobalGen; its effect there is not visible
// in this file.
enum Mode {
	// Bodies are generated at declaration time only for local functions
	// and functions in templates.
	Lazy,
	// Any declared function that has a body is generated at declaration time.
	Eager,
}
17 
// Description of one closure (context) frame.
struct Closure {
private:
	// Index of each variable stored in the frame; used as the struct GEP
	// index when computing the address of the variable.
	uint[Variable] indices;
	// LLVM type of the frame, used when allocating it.
	LLVMTypeRef type;
}
23 
// Data recorded by LocalGen about aggregates that carry a context.
struct LocalData {
private:
	// Closure chain that was active when an aggregate with a context was
	// defined (see LocalGen.define(Aggregate)); read back when generating
	// the captures reachable through a `this` parameter.
	Closure[][Aggregate] embededContexts;
}
28 
29 struct LocalGen {
	// Underlying code generation pass; its members are reachable directly
	// through alias this.
	CodeGen pass;
	alias pass this;
	
	// Builder created in the constructor and disposed in the destructor.
	LLVMBuilderRef builder;
	
	// Definition mode, consulted by declare(Function).
	Mode mode;
	
	// Pointer to the context (closure frame) of the function being
	// generated; null until one is created.
	LLVMValueRef ctxPtr;
	
	// LLVM value registered for each local symbol.
	LLVMValueRef[ValueSymbol] locals;
	
	// Chain of closures visible from the current function; the last
	// element is the closure of the function being generated.
	Closure[] contexts;
	
	// NOTE(review): not used in this file; presumably landing pad state
	// used by other generators -- confirm.
	LLVMValueRef lpContext;
	LLVMBasicBlockRef lpBB;
45 	
46 	this(CodeGen pass, Mode mode = Mode.Lazy, Closure[] contexts = []) {
47 		this.pass = pass;
48 		this.mode = mode;
49 		this.contexts = contexts;
50 		
51 		// Make sure locals are initialized.
52 		locals[null] = null;
53 		locals.remove(null);
54 		
55 		// Make sure we alays have a builder ready to rock.
56 		builder = LLVMCreateBuilderInContext(llvmCtx);
57 	}
58 	
	// Release the builder created by the constructor.
	~this() {
		LLVMDisposeBuilder(builder);
	}
	
	// Copying is disabled: the destructor disposes the builder, so a copy
	// would dispose the same builder a second time.
	@disable this(this);
64 	
65 	void define(Symbol s) {
66 		if (auto v = cast(Variable) s) {
67 			define(v);
68 		} else if (auto f = cast(Function) s) {
69 			define(f);
70 		} else if (auto a = cast(Aggregate) s) {
71 			define(a);
72 		} else {
73 			import d.llvm.global;
74 			GlobalGen(pass, mode).define(s);
75 		}
76 	}
77 	
78 	void require(Function f) {
79 		if (f.step == Step.Processed) {
80 			return;
81 		}
82 		
83 		LLVMValueRef[] unreachables;
84 		auto backupCurrentBlock = LLVMGetInsertBlock(builder);
85 		scope(exit) {
86 			foreach(u; unreachables) {
87 				LLVMInstructionEraseFromParent(u);
88 			}
89 			
90 			LLVMPositionBuilderAtEnd(builder, backupCurrentBlock);
91 		}
92 		
93 		// OK we need to require. We need to put the module in a good state.
94 		for (
95 			auto fun = LLVMGetFirstFunction(dmodule);
96 			fun !is null;
97 			fun = LLVMGetNextFunction(fun)
98 		) {
99 			for (
100 				auto bb = LLVMGetFirstBasicBlock(fun);
101 				bb !is null;
102 				bb = LLVMGetNextBasicBlock(bb)
103 			) {
104 				if (!LLVMGetBasicBlockTerminator(bb)) {
105 					LLVMPositionBuilderAtEnd(builder, bb);
106 					unreachables ~= LLVMBuildUnreachable(builder);
107 				}
108 			}
109 		}
110 		
111 		scheduler.require(f);
112 	}
113 	
114 	LLVMValueRef declare(Function f) {
115 		require(f);
116 		
117 		// XXX: This should probably a member of the Function class.
118 		auto isLocal = f.hasContext || (cast(NestedScope) f.getParentScope());
119 		auto lookup = isLocal ? locals : globals;
120 		
121 		// FIXME: This is broken, but we do it all in globals for now.
122 		// We have no good way to pas the nested locals down in aggregates
123 		// declarations as we do a round trip through globals.
124 		// We could fix this by removing any require from the backend
125 		// and moving local to the localData, or bubbling down part of the
126 		// aggregate declaration code in the LocalGen. This last option seems
127 		// more reasonable as the situation is also broken for embededContexts.
128 		// In the meantime, just store everything in globals.
129 		lookup = globals;
130 		
131 		auto fun = lookup.get(f, {
132 			auto name = f.mangle.toStringz(pass.context);
133 			
134 			import d.llvm.type;
135 			auto type = LLVMGetElementType(TypeGen(pass).visit(f.type));
136 			
137 			// The method may have been defined when visiting the type.
138 			if (auto funPtr = f in lookup) {
139 				return *funPtr;
140 			}
141 			
142 			// Sanity check: do not declare multiple time.
143 			assert(
144 				!LLVMGetNamedFunction(pass.dmodule, name),
145 				f.mangle.toString(pass.context) ~ " is already declared.",
146 			);
147 			
148 			return lookup[f] = LLVMAddFunction(pass.dmodule, name, type);
149 		} ());
150 		
151 		if (isLocal || f.inTemplate || mode == Mode.Eager) {
152 			if (f.fbody && maybeDefine(f, fun)) {
153 				LLVMSetLinkage(fun, LLVMLinkage.LinkOnceODR);
154 			}
155 		}
156 		
157 		return fun;
158 	}
159 	
160 	LLVMValueRef define(Function f) {
161 		auto fun = declare(f);
162 		if (!f.fbody && !f.intrinsicID) {
163 			return fun;
164 		}
165 		
166 		if (maybeDefine(f, fun)) {
167 			return fun;
168 		}
169 		
170 		auto linkage = LLVMGetLinkage(fun);
171 		assert(
172 			linkage == LLVMLinkage.LinkOnceODR,
173 			"function " ~ f.mangle.toString(context) ~ " already defined",
174 		);
175 		
176 		LLVMSetLinkage(fun, LLVMLinkage.External);
177 		return fun;
178 	}
179 	
180 	private bool maybeDefine(Function f, LLVMValueRef fun) in {
181 		assert(f.step == Step.Processed, "f is not processed");
182 	} do {
183 		auto countBB = LLVMCountBasicBlocks(fun);
184 		if (countBB) {
185 			return false;
186 		}
187 		
188 		auto contexts = f.hasContext ? this.contexts : [];
189 		LocalGen(pass, mode, contexts).genBody(f, fun);
190 		
191 		return true;
192 	}
193 	
194 	private void genBody(Function f, LLVMValueRef fun) in {
195 		assert(
196 			LLVMCountBasicBlocks(fun) == 0,
197 			f.mangle.toString(context) ~ " body is already defined"
198 		);
199 		
200 		assert(f.step == Step.Processed, "f is not processed");
201 		assert(f.fbody || f.intrinsicID, "f must have a body");
202 	} do {
203 		scope(failure) f.dump(context);
204 		
205 		// Alloca and instruction block.
206 		auto allocaBB = LLVMAppendBasicBlockInContext(llvmCtx, fun, "");
207 		
208 		// Handle parameters in the alloca block.
209 		LLVMPositionBuilderAtEnd(builder, allocaBB);
210 		
211 		auto funType = LLVMGetElementType(LLVMTypeOf(fun));
212 		
213 		LLVMValueRef[] params;
214 		LLVMTypeRef[] paramTypes;
215 		params.length = paramTypes.length = LLVMCountParamTypes(funType);
216 		LLVMGetParams(fun, params.ptr);
217 		LLVMGetParamTypes(funType, paramTypes.ptr);
218 		
219 		// If this function is a known intrinsic, swap implementation.
220 		if (f.intrinsicID) {
221 			import d.llvm.expression, d.llvm.intrinsic;
222 			LLVMBuildRet(
223 				builder,
224 				ExpressionGen(&this).buildBitCast(
225 					IntrinsicGen(&this).build(f.intrinsicID, params),
226 					LLVMGetReturnType(funType),
227 				),
228 			);
229 			return;
230 		}
231 		
232 		auto parameters = f.params;
233 		
234 		import d.llvm.type;
235 		auto closure = Closure(f.closure, TypeGen(pass).visit(f));
236 		if (f.hasContext) {
237 			auto parentCtxType = f.type.parameters[0];
238 			assert(parentCtxType.isRef || parentCtxType.isFinal);
239 			
240 			auto parentCtx = params[0];
241 			LLVMSetValueName(parentCtx, "__ctx");
242 			
243 			// Find the right context as parent.
244 			import d.llvm.type;
245 			auto ctxTypeGen = TypeGen(pass).visit(parentCtxType.getType());
246 			
247 			import std.algorithm, std.range;
248 			auto ctxCount = contexts.length -
249 				retro(contexts).countUntil!(c => c.type is ctxTypeGen)();
250 			contexts = contexts[0 .. ctxCount];
251 			
252 			buildCapturedVariables(parentCtx, contexts, f.getCaptures());
253 			
254 			// Chain closures.
255 			ctxPtr = LLVMBuildAlloca(builder, closure.type, "");
256 			
257 			auto ctxStorage = LLVMBuildStructGEP(builder, ctxPtr, 0, "");
258 			LLVMBuildStore(builder, parentCtx, ctxStorage);
259 			contexts ~= closure;
260 		} else {
261 			// Build closure for this function.
262 			import d.llvm.type;
263 			closure.type = TypeGen(pass).visit(f);
264 			contexts = [closure];
265 		}
266 		
267 		params = params[f.hasContext .. $];
268 		paramTypes = paramTypes[f.hasContext .. $];
269 		
270 		foreach(i, p; parameters) {
271 			auto value = params[i];
272 			
273 			auto ptr = createVariableStorage(p, value);
274 			if (!p.isRef && !p.isFinal) {
275 				import std.string;
276 				LLVMSetValueName(
277 					value,
278 					toStringz("arg." ~ p.name.toString(context)),
279 				);
280 			}
281 			
282 			// this is kind of magic :)
283 			import source.name;
284 			if (p.name == BuiltinName!"this") {
285 				buildEmbededCaptures(ptr, p.type);
286 			}
287 		}
288 		
289 		// Generate function's body.
290 		import d.llvm.statement;
291 		StatementGen(&this).visit(f.fbody);
292 		
293 		// If we have a context, let's make it the right size.
294 		if (ctxPtr !is null) {
295 			auto ctxAlloca = ctxPtr;
296 			while(LLVMGetInstructionOpcode(ctxAlloca) != LLVMOpcode.Alloca) {
297 				assert(LLVMGetInstructionOpcode(ctxAlloca) == LLVMOpcode.BitCast);
298 				ctxAlloca = LLVMGetOperand(ctxAlloca, 0);
299 			}
300 			
301 			LLVMPositionBuilderBefore(builder, ctxAlloca);
302 			
303 			auto ctxType = contexts[$ - 1].type;
304 			
305 			import d.llvm.expression;
306 			auto alloc = ExpressionGen(&this).buildCall(
307 				declare(pass.object.getGCThreadLocalAllow()),
308 				[LLVMSizeOf(ctxType)],
309 			);
310 			
311 			// XXX: This should be set on the alloc function instead of the callsite.
312 			LLVMAddCallSiteAttribute(
313 				alloc,
314 				LLVMAttributeReturnIndex,
315 				getAttribute("noalias"),
316 			);
317 			
318 			LLVMReplaceAllUsesWith(ctxAlloca, LLVMBuildPointerCast(
319 				builder,
320 				alloc,
321 				LLVMPointerType(ctxType, 0),
322 				"",
323 			));
324 		}
325 	}
326 	
327 	private void buildEmbededCaptures(LLVMValueRef thisPtr, Type t) {
328 		if (t.kind == TypeKind.Struct) {
329 			auto s = t.dstruct;
330 			if (!s.hasContext) {
331 				return;
332 			}
333 			
334 			buildEmbededCaptures(thisPtr, s, 0);
335 		} else if (t.kind == TypeKind.Class) {
336 			auto c = t.dclass;
337 			if (!c.hasContext) {
338 				return;
339 			}
340 			
341 			import source.name, std.algorithm, std.range;
342 			auto f = retro(c.members)
343 				.filter!(m => m.name == BuiltinName!"__ctx")
344 				.map!(m => cast(Field) m)
345 				.front;
346 			
347 			buildEmbededCaptures(thisPtr, c, f.index);
348 		} else {
349 			assert(0, typeid(t).toString() ~ " is not supported.");
350 		}
351 	}
352 	
353 	private void buildEmbededCaptures(S)(
354 		LLVMValueRef thisPtr,
355 		S s,
356 		uint i,
357 	) if (is(S : Scope)) {
358 		buildCapturedVariables(LLVMBuildLoad(
359 			builder,
360 			LLVMBuildStructGEP(builder, thisPtr, i, ""),
361 			"",
362 		), localData.embededContexts[s], s.getCaptures());
363 	}
364 	
	// Register in locals the captured variables listed in capture.
	//
	// root is a pointer to the frame described by the last element of
	// contexts. Closures are visited from last to first; after each one,
	// root is replaced by the pointer loaded from field 0 of the frame,
	// which is where the pointer to the parent frame is stored.
	//
	// Every captured variable is expected to be found in one of the closures.
	private void buildCapturedVariables(
		LLVMValueRef root,
		Closure[] contexts,
		bool[Variable] capture,
	) {
		// Number of captured variables that remain to be found.
		auto closureCount = capture.length;
		
		// Try to find out if we have the variable in a closure.
		foreach_reverse(closure; contexts) {
			// Everything was found, no need to go further up the chain.
			if (!closureCount) {
				break;
			}
			
			// Create enclosed variables.
			foreach(v; capture.byKey()) {
				if (auto indexPtr = v in closure.indices) {
					// Register the variable.
					auto var = LLVMBuildStructGEP(
						builder,
						root,
						*indexPtr,
						"",
					);
					
					// For ref and final variables, the value loaded from
					// the frame is registered rather than the field's address.
					if (v.isRef || v.isFinal) {
						var = LLVMBuildLoad(builder, var, "");
					}
					
					LLVMSetValueName(var, v.mangle.toStringz(context));
					locals[v] = var;
					
					assert(closureCount > 0, "closureCount is 0 or lower.");
					closureCount--;
				}
			}
			
			// Move on to the parent frame.
			auto rootPtr = LLVMBuildStructGEP(builder, root, 0, "");
			root = LLVMBuildLoad(builder, rootPtr, "");
		}
		
		assert(closureCount == 0);
	}
407 	
408 	LLVMValueRef declare(Variable v) {
409 		if (v.storage.isGlobal) {
410 			import d.llvm.global;
411 			return GlobalGen(pass, mode).declare(v);
412 		}
413 		
414 		// TODO: Actually just declare here :)
415 		return locals.get(v, define(v));
416 	}
417 	
418 	LLVMValueRef define(Variable v) in {
419 		assert(!v.isFinal);
420 	} do {
421 		if (v.storage.isGlobal) {
422 			import d.llvm.global;
423 			return GlobalGen(pass, mode).define(v);
424 		}
425 		
426 		import d.llvm.expression;
427 		auto value = v.isRef
428 			? AddressOfGen(&this).visit(v.value)
429 			: ExpressionGen(&this).visit(v.value);
430 		
431 		return createVariableStorage(v, value);
432 	}
433 	
	// Create the storage for the local variable v, initialize it with
	// value and register the result in locals.
	//
	// For ref and final variables, value itself is registered; if the
	// variable is captured, value is also stored in the context frame.
	// For other variables, the address of the storage is registered.
	//
	// Returns the LLVM value registered for v.
	private LLVMValueRef createVariableStorage(
		Variable v,
		LLVMValueRef value,
	) in {
		assert(v.storage.isLocal, "globals not supported");
	} do {
		auto name = v.name.toStringz(context);
		
		if (v.isRef || v.isFinal) {
			if (v.storage == Storage.Capture) {
				auto addr = createCaptureStorage(v, "");
				LLVMBuildStore(builder, value, addr);
			}
			
			// Only name the value if it does not have a name already.
			if (LLVMGetValueName(value)[0] == '\0') {
				LLVMSetValueName(value, name);
			}
			
			return locals[v] = value;
		}
		
		// Backup current block
		// and move to the end of the first block of the function, where
		// the storage is created.
		auto backupCurrentBlock = LLVMGetInsertBlock(builder);
		LLVMPositionBuilderAtEnd(builder, LLVMGetFirstBasicBlock(
			LLVMGetBasicBlockParent(backupCurrentBlock),
		));
		
		// Sanity check
		scope(success) {
			assert(LLVMGetInsertBlock(builder) is backupCurrentBlock);
		}
		
		// Captured variables live in the context frame, others in an alloca.
		import d.llvm.type;
		LLVMValueRef addr = (v.storage == Storage.Capture)
			? createCaptureStorage(v, name)
			: LLVMBuildAlloca(builder, TypeGen(pass).visit(v.type), name);
		
		// Store the initial value into the alloca.
		// This is done back in the block we came from.
		LLVMPositionBuilderAtEnd(builder, backupCurrentBlock);
		LLVMBuildStore(builder, value, addr);
		
		// Register the variable.
		return locals[v] = addr;
	}
478 	
479 	LLVMValueRef createCaptureStorage(Variable v, const char* name) in {
480 		assert(v.storage == Storage.Capture, "Expected captured");
481 	} do {
482 		auto closure = &contexts[$ - 1];
483 		
484 		// If we don't have a closure, make one.
485 		if (ctxPtr is null) {
486 			ctxPtr = LLVMBuildAlloca(builder, closure.type, "");
487 		}
488 		
489 		return LLVMBuildStructGEP(
490 			builder,
491 			ctxPtr,
492 			closure.indices[v],
493 			name,
494 		);
495 	}
496 	
497 	LLVMValueRef getContext(Function f) {
498 		import d.llvm.type;
499 		auto type = TypeGen(pass).visit(f);
500 		auto value = ctxPtr;
501 		foreach_reverse(i, c; contexts) {
502 			if (value is null) {
503 				return LLVMConstNull(LLVMPointerType(type, 0));
504 			}
505 			
506 			if (c.type is type) {
507 				auto ptrType = LLVMPointerType(type, 0);
508 				return LLVMBuildPointerCast(builder, value, ptrType, "");
509 			}
510 			
511 			auto ctxPtr = LLVMBuildStructGEP(builder, value, 0, "");
512 			value = LLVMBuildLoad(builder, ctxPtr, "");
513 		}
514 		
515 		assert(0, "No context available.");
516 	}
517 	
518 	LLVMTypeRef define(Aggregate a) {
519 		if (a.hasContext) {
520 			localData.embededContexts[a] = contexts;
521 		}
522 		
523 		import d.llvm.global;
524 		return GlobalGen(pass, mode).define(a);
525 	}
526 }