import json
import logging
import os
import random

from fastapi import FastAPI
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware

from db import Neo4j


class SoftwareSearch(BaseModel):
    """Request body for ``POST /software-search``: the two softwares to look up."""

    # First software of the pair (presumably a node name — confirm against db).
    software1: str
    # Second software of the pair.
    software2: str


class SoftwareInsert(BaseModel):
    """Request body for ``POST /software-insert``; handed to ``driver.insert_software``."""

    # Name for the node to create (presumably — confirm against db.insert_software).
    newNodeName: str
    # Relationship type to create — TODO confirm semantics in db.insert_software.
    relationType: str
    # Id of the existing node to relate to (presumably — confirm).
    nodeId: str


# Shared Neo4j wrapper used by every route handler below. Connection settings
# come from the environment; a missing variable is passed through as None.
# NOTE(review): the handlers are ``async def`` but call ``driver`` methods
# without ``await``; if those calls block on network I/O they stall the event
# loop. Consider plain ``def`` handlers once db.Neo4j is confirmed thread-safe.
driver = Neo4j(
    os.environ.get("NEO4J_URI"),
    os.environ.get("NEO4J_USER"),
    os.environ.get("NEO4J_PASSWORD"),
)
app = FastAPI()

# Allowed CORS origins from FRONTEND_URL: one URL, or several separated by
# commas. Blank entries are dropped so an unset variable yields an empty list
# instead of ``[None]``.
origins = [
    origin.strip()
    for origin in os.getenv("FRONTEND_URL", "").split(",")
    if origin.strip()
]

# Restrict CORS to the configured frontend(s) when FRONTEND_URL is set;
# otherwise keep the previous allow-everything behaviour (e.g. local dev).
# NOTE(review): "*" combined with allow_credentials=True is very permissive;
# make sure FRONTEND_URL is set in production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Return a fixed greeting payload."""
    greeting = "Hello World"
    return dict(message=greeting)


@app.get("/hello/{name}")
async def say_hello(name: str):
    """Return a greeting addressed to the ``name`` path parameter."""
    greeting = "Hello {}".format(name)
    return dict(message=greeting)


@app.get("/graph")
async def graph():
    """Return the JSON-decoded ``data`` field of the first graph-data row.

    NOTE(review): an empty result from ``driver.get_graph_data()`` raises
    IndexError here (surfacing as HTTP 500) — confirm whether that can occur.
    """
    first_row = driver.get_graph_data()[0]
    encoded_graph = first_row['data']
    return json.loads(encoded_graph)


@app.get("/software")
async def software():
    """Return ``driver.get_softwares()`` under the ``"softwares"`` key.

    The JSON round-trip converts the driver result into plain JSON types and
    fails fast on anything that is not JSON-serialisable.

    NOTE: the ``/software-graph`` handler below reuses the name ``software``,
    so this function is shadowed at module level (its route stays registered).
    """
    raw_softwares = driver.get_softwares()
    plain_softwares = json.loads(json.dumps(raw_softwares))
    return {"softwares": plain_softwares}


@app.get("/software-graph")
async def software():
    """Return all nodes and relations reshaped for a graph-drawing client.

    Nodes and relations come from ``driver.get_nodes()`` / ``driver.get_relations()``
    and are copied through a JSON round-trip. The result is
    ``{"nodes": [...], "edges": [...]}`` where:

    * each edge gets a string ``key`` (its index), empty ``attributes``, a
      ``label`` taken from ``relation_type[1]``, and ``source``/``target``
      rewritten from node names to node keys (``0`` when no node has that
      name); ``relation_type`` is removed;
    * each node gets ``attributes`` holding ``label`` (its former ``name``,
      which is removed), random ``x``/``y`` in [0, 1) and ``size`` 22.

    NOTE: this function reuses the name ``software`` and shadows the
    ``/software`` handler at module level; both routes remain registered.
    """
    nodes = json.loads(json.dumps(driver.get_nodes()))
    edges = json.loads(json.dumps(driver.get_relations()))

    # Index nodes by name once so each edge endpoint resolves in O(1) instead of
    # rescanning the node list per lookup. setdefault keeps the FIRST node with
    # a given name, matching the previous linear search.
    node_by_name = {}
    for node in nodes:
        node_by_name.setdefault(node["name"], node)

    def node_key(name):
        """Return the key of the first node called ``name``, or 0 if none."""
        match = node_by_name.get(name)
        return match["key"] if match is not None else 0

    for idx, edge in enumerate(edges):
        edge['key'] = str(idx)
        edge['attributes'] = {}
        # NOTE(review): assumes relation_type is a sequence whose element [1]
        # is the display label — confirm against Neo4j.get_relations.
        edge['label'] = edge['relation_type'][1]
        edge['source'] = node_key(edge["source"])
        edge['target'] = node_key(edge["target"])
        del edge['relation_type']

    for node in nodes:
        # x/y are random placeholders; presumably the client lays the graph out
        # itself — TODO confirm.
        node['attributes'] = {
            'label': node.pop('name'),
            'x': random.random(),
            'y': random.random(),
            'size': 22,
        }

    return {"nodes": nodes, "edges": edges}


@app.post("/software-search")
async def software_search(search: SoftwareSearch):
    """Query ``driver.get_softwares_shortest_path`` for the two softwares in ``search``.

    The driver result is returned unchanged under the ``"softwares"`` key.
    """
    # Use logging rather than print() so the diagnostic respects the app's
    # logging configuration instead of always writing to stdout.
    logging.getLogger(__name__).debug(
        "software search: %s -> %s", search.software1, search.software2
    )
    res = driver.get_softwares_shortest_path(search)
    return {"softwares": res}


@app.post("/software-insert")
async def software_insert(insert_data: SoftwareInsert):
    """Forward the request body to ``driver.insert_software`` and return its result as-is."""
    return driver.insert_software(insert_data)