// g++ llvmtest.cpp `llvm-config --cxxflags --ldflags --libs core engine`
#include <cstdio>
#include <vector>

#include "llvm/Module.h"
#include "llvm/Function.h"
#include "llvm/PassManager.h"
#include "llvm/CallingConv.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/PrintModulePass.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"


using namespace llvm;

// Emit `i32 mul_add(i32 x, i32 y, i32 z)` into `mod`, whose body computes
// x * y + z in a single "entry" block.
void BuildMulAdd(Module *mod)
{
    // Look up or declare the prototype. The argument-type list passed to
    // getOrInsertFunction is variadic and must be terminated with NULL.
    Function *fn = cast<Function>(
        mod->getOrInsertFunction("mul_add",
                                 IntegerType::get(32),   // return type
                                 IntegerType::get(32),   // x
                                 IntegerType::get(32),   // y
                                 IntegerType::get(32),   // z
                                 NULL));
    fn->setCallingConv(CallingConv::C);

    // Give the three formal parameters readable names and keep handles to them.
    static const char *const paramNames[3] = { "x", "y", "z" };
    Value *params[3];
    Function::arg_iterator it = fn->arg_begin();
    for (unsigned i = 0; i != 3; ++i, ++it) {
        params[i] = it;
        params[i]->setName(paramNames[i]);
    }

    // Body: tmp = x * y; tmp2 = tmp + z; ret tmp2
    IRBuilder<> builder(BasicBlock::Create("entry", fn));
    Value *product = builder.CreateMul(params[0], params[1], "tmp");
    Value *sum = builder.CreateAdd(product, params[2], "tmp2");
    builder.CreateRet(sum);
}

// Emit `i32 gcd(i32 x, i32 y)` into `mod` using recursive subtraction:
//   if (x == y)        return x;
//   if (x <u y)        return gcd(x, y - x);
//   else               return gcd(x - y, y);
// The ordering test is an unsigned compare (icmp ult), so the operands are
// effectively treated as unsigned 32-bit values.
// NOTE: if exactly one argument is 0, the recursion never terminates
// (e.g. gcd(0, 5) calls gcd(0, 5) again), so callers must pass non-zero values.
void BuildGCD(Module *mod)
{
    // Look up or declare the prototype. The argument-type list passed to
    // getOrInsertFunction is variadic and NULL-terminated.
    Constant *c = mod->getOrInsertFunction("gcd",
                                           IntegerType::get(32),
                                           IntegerType::get(32),
                                           IntegerType::get(32),
                                           NULL);
    Function *gcd = cast<Function>(c);
    // Use the same calling convention as mul_add, set explicitly for consistency.
    gcd->setCallingConv(CallingConv::C);

    Function::arg_iterator args = gcd->arg_begin();
    Value *x = args++;
    x->setName("x");
    Value *y = args++;
    y->setName("y");

    BasicBlock *entry = BasicBlock::Create("entry", gcd);
    BasicBlock *ret = BasicBlock::Create("return", gcd);
    BasicBlock *condFalse = BasicBlock::Create("condFalse", gcd);
    BasicBlock *condTrue = BasicBlock::Create("condTrue", gcd);
    // Previously this block was also named "condTrue", which LLVM silently
    // uniquified, making the printed IR misleading.
    BasicBlock *condFalse2 = BasicBlock::Create("condFalse2", gcd);

    // entry: x == y ? return x : keep going
    IRBuilder<> builder(entry);
    Value *xeqy = builder.CreateICmpEQ(x, y, "tmp");
    builder.CreateCondBr(xeqy, ret, condFalse);

    builder.SetInsertPoint(ret);
    builder.CreateRet(x);

    // condFalse: pick which operand to subtract from which (unsigned compare).
    builder.SetInsertPoint(condFalse);
    Value *xlty = builder.CreateICmpULT(x, y, "tmp");
    builder.CreateCondBr(xlty, condTrue, condFalse2);

    // condTrue (x < y): return gcd(x, y - x)
    builder.SetInsertPoint(condTrue);
    Value *yminusx = builder.CreateSub(y, x, "tmp");
    std::vector<Value *> args1;
    args1.push_back(x);
    args1.push_back(yminusx);
    Value *recurCall1 = builder.CreateCall(gcd, args1.begin(), args1.end(), "tmp");
    builder.CreateRet(recurCall1);

    // condFalse2 (x > y): return gcd(x - y, y)
    builder.SetInsertPoint(condFalse2);
    Value *xminusy = builder.CreateSub(x, y, "tmp");
    std::vector<Value *> args2;
    args2.push_back(xminusy);
    args2.push_back(y);
    Value *recurCall2 = builder.CreateCall(gcd, args2.begin(), args2.end(), "tmp");
    builder.CreateRet(recurCall2);
}

// Build a fresh module named "test" containing mul_add followed by gcd.
// The caller receives ownership of the returned heap-allocated Module.
Module* makeLLVMModule()
{
    Module *result = new Module("test");

    // Definition order here is the order the functions appear in the module.
    BuildMulAdd(result);
    BuildGCD(result);

    return result;
}

// Build the demo module, verify and print it, JIT-compile both functions,
// and print the results of a few sample calls to stderr.
int main(int argc, char**argv)
{
    Module* Mod = makeLLVMModule();

    // verifyModule returns true when the module is broken. PrintMessageAction
    // only prints diagnostics (it does not abort), so stop here explicitly
    // rather than JIT-compiling invalid IR.
    if (verifyModule(*Mod, PrintMessageAction)) {
        fprintf(stderr, "error: module failed verification\n");
        delete Mod;
        return 1;
    }

    // Dump the textual IR to stdout.
    PassManager PM;
    PM.add(createPrintModulePass(&outs()));
    PM.run(*Mod);

    // On success the engine takes ownership of Mod (through the module
    // provider it wraps it in), so from here on the engine, not this
    // function, is responsible for deleting the module.
    ExecutionEngine *engine = ExecutionEngine::create(Mod);
    if (!engine) {
        fprintf(stderr, "error: could not create an execution engine\n");
        // NOTE(review): no engine was created, so nothing took ownership of
        // Mod; free it ourselves. Confirm against the LLVM version in use.
        delete Mod;
        return 1;
    }

    void *mulAddAddr = engine->getPointerToFunction(Mod->getFunction("mul_add"));
    void *gcdAddr = engine->getPointerToFunction(Mod->getFunction("gcd"));
    if (!mulAddAddr || !gcdAddr) {
        fprintf(stderr, "error: failed to JIT-compile the test functions\n");
        delete engine;
        return 1;
    }

    typedef int (*MulAddFptr)(int, int, int);
    typedef int (*GCDFptr)(int, int);
    MulAddFptr fptr = reinterpret_cast<MulAddFptr>(mulAddAddr);
    GCDFptr gcd = reinterpret_cast<GCDFptr>(gcdAddr);

    // %p requires a void*, so print the raw addresses rather than the
    // function pointers themselves.
    fprintf(stderr, "%p: 2*3+4 = %d\n", mulAddAddr, fptr(2, 3, 4));
    fprintf(stderr, "%p: gcd(10, 25) = %d, gcd(1234, 5678) = %d\n", gcdAddr, gcd(10, 25), gcd(1234, 5678));

    // Deleting the engine also releases the module it owns; deleting Mod
    // directly would leave the engine dangling (and the engine leaked).
    delete engine;
    return 0;
}

