/* Copyright 1996 Acorn Computers Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/* File:        mem.c
 * Purpose:     memory allocation in the Toolbox
 * Author:      Ian Johnson
 * History:     6-Aug-93: IDJ: created
 */



/*
 * This is a central memory allocation place for the Toolbox.
 * A doubly-linked list of all allocated blocks is kept (maybe remove this
 * if using too much memory!).
 * To enable debugging of memory allocation, the calling functions should be
 * compiled without -ff, so that fn names are in the code area.  Compile
 * this code with -DDEBUG_MEMORY.
 * Memory tracing is turned on by setting the letter 'm' in the central
 * debug string.
 *
 */

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "kernel.h"
#include "swis.h"

#include "debug.h"

#include "mem.h"
#include "objects.toolbox.h"

#undef mem_allocate
#undef mem_free
#undef mem_print_list
#undef mem_extend
#undef mem_chk

#include "toolboxmem.h"
/*
 * Route the generic allocator names used in the rest of this file through
 * the dispatching wrappers defined further down (toolbox_memory_alloc_c,
 * toolbox_memory_free_c and toolbox_memory_extend_c), which choose an
 * implementation according to memory_sys_flag.
 *
 * Writing a name in parentheses, e.g. (calloc)(n, 1), suppresses these
 * function-like macros and reaches the underlying function instead.
 *
 * Note that the two arguments of calloc are simply multiplied; the
 * product is not checked for overflow.
 */
#define calloc(m,s) toolbox_memory_alloc_c((m)*(s))
#define free(p) toolbox_memory_free_c(p)
#define reallocate(b,by) toolbox_memory_extend_c(b,by)

/*
 * Header placed immediately in front of every block handed out by
 * mem_allocate.  The headers are chained into the doubly-linked list
 * headed by mem__block_list; the caller's data starts directly after
 * the header (i.e. at header + 1).
 */
typedef struct block
{
    struct block *next;       /* next one in the chain (NULL at the tail) */
    struct block *prev;       /* previous one in the chain (NULL at the head) */
    unsigned int  size;        /* size of allocated block, as requested by the
                                  caller (does not include this header) */

#ifdef DEBUG_MEMORY
    int           guard;       /* should be 'TBOX' (see GUARD_WORD); checked
                                  when the block is freed */
    char         *fn;          /* name of function which did the allocating;
                                  separately allocated copy, or 0 if unknown */
    char         *tag;         /* extra info entered by caller (eg line #);
                                  separately allocated copy, or 0 if none */
#endif

                              /* caller gets a pointer to here */
} mem__Block;

/*
 * Guard value stored in the header of every block when DEBUG_MEMORY is
 * defined: the four characters 'T','B','O','X' read back as one int.
 * The expansion is fully parenthesised so that the macro behaves as a
 * single primary expression wherever it is used.
 */
#define GUARD_WORD (*(int *)"TBOX")



/*
 * Head of the list of headers for every block currently allocated through
 * this file.  New blocks are pushed on the front by mem__add_block and
 * unlinked by mem__remove_block; 0 means the list is empty.
 */
static mem__Block *mem__block_list = 0;    /* linked list of all allocated blocks */

/*
 * Which allocator the toolbox_memory_*_c wrappers dispatch to.
 */
enum {
        /* use toolbox_memory_alloc / _free / _extend */
        memorysys_USE_TOOLBOX	= 1,
        /* use the C library (calloc)/(free) and the local (reallocate) */
        memorysys_USE_OWN_CODE	= 2
};

/*
 * Current allocator choice.  Starts at 0, meaning "not decided yet": the
 * wrappers call mem_init the first time they see a value that is neither
 * of the two enumerators above.
 */
static int memory_sys_flag = 0;

/* This function must only ever be called by the Toolbox module and is used to bypass the
 * problem of not being able to call our own SWIs during module initialisation.
 *
 * It forces the allocator choice to memorysys_USE_TOOLBOX directly, so that
 * the wrappers never need to call mem_init (which probes the Toolbox_Memory
 * SWI) to find out.
 */
extern void mem_i_am_the_toolbox (void)
{
        memory_sys_flag = memorysys_USE_TOOLBOX;
}

extern void mem_init (void)
{
    /*
     * Work out which allocator the rest of this file should use by probing
     * the Toolbox_Memory SWI with R0 = 2 and R1, R2 = 0.
     *
     * A NULL result from _swix means the call was accepted, i.e. a suitable
     * Toolbox module *was* present; any error return makes us fall back on
     * our own code.
     */
    memory_sys_flag = (_swix(Toolbox_Memory, _INR(0,2), 2, 0, 0) == NULL)
                          ? memorysys_USE_TOOLBOX
                          : memorysys_USE_OWN_CODE;
}

static void *(reallocate)(void *p, int by)
{
    /*
     * own implementation of realloc, cos broken in 3.10
     */

     DEBUG debug_output ("memory", "Reallocating %p to %d\n", p, by);

    if (p == 0 || by == 0)  /* not like real realloc! */
        return 0;
    else
    {
        void *newp;

        if (_swix (OS_Module, _IN(0)|_INR(2,3)|_OUT(2),
                              13,  /* extend RMA block */
                              p,
                              by,
                              &newp
                  ) != 0)
        {
           DEBUG debug_output ("memory", "Failed\n");
            return 0;
        }
        else
        {
           DEBUG debug_output ("memory", "Old p %p, new p %p\n", p, newp);
            return newp;
        }
    }

}


static void *toolbox_memory_alloc_c(int s)
{
    /*
     * Allocate s bytes using whichever allocator memory_sys_flag selects.
     * If no allocator has been chosen yet, mem_init is run first and the
     * request is then retried.
     */
    switch (memory_sys_flag)
    {
        case memorysys_USE_TOOLBOX:
            return toolbox_memory_alloc(s);

        case memorysys_USE_OWN_CODE:
            /* parenthesised to reach the C library, not the macro */
            return (calloc)(s, 1);

        default:
            break;
    }

    mem_init();
    return toolbox_memory_alloc_c(s);
}

static void toolbox_memory_free_c(void *p)
{
    /*
     * Release p using whichever allocator memory_sys_flag selects.
     * If no allocator has been chosen yet, mem_init is run first and the
     * request is then retried.
     */
    switch (memory_sys_flag)
    {
        case memorysys_USE_TOOLBOX:
            toolbox_memory_free(p);
            return;

        case memorysys_USE_OWN_CODE:
            /* parenthesised to reach the C library, not the macro */
            (free)(p);
            return;

        default:
            break;
    }

    mem_init();
    toolbox_memory_free_c(p);
}

static void *toolbox_memory_extend_c(void *b, int by)
{
    /*
     * Resize block b by "by" bytes using whichever allocator
     * memory_sys_flag selects.  If no allocator has been chosen yet,
     * mem_init is run first and the request is then retried.
     */
    switch (memory_sys_flag)
    {
        case memorysys_USE_TOOLBOX:
            return toolbox_memory_extend(b, by);

        case memorysys_USE_OWN_CODE:
            /* parenthesised to reach the local function, not the macro */
            return (reallocate)(b, by);

        default:
            break;
    }

    mem_init();
    return toolbox_memory_extend_c(b, by);
}



static void mem__add_block (mem__Block *b)
{
    /*
     * Push header b onto the front of the list of allocated blocks and
     * reset its bookkeeping fields (size, and the debug name/tag when
     * DEBUG_MEMORY is defined).  The caller fills those in afterwards.
     */

    mem__Block *old_head = mem__block_list;

    b->size = 0;

#ifdef DEBUG_MEMORY
    b->fn  = 0;
    b->tag = 0;
#endif

    /* link in ahead of the previous head, if there was one */
    b->prev = NULL;
    b->next = old_head;
    if (old_head != NULL)
        old_head->prev = b;

    mem__block_list = b;
}


static void mem__remove_block (mem__Block *b)
{
    /*
     * Unlink header b from the list of allocated blocks.  b itself is not
     * modified or freed; only its neighbours (and the list head, when b is
     * the first element) are updated.
     */

    mem__Block *after = b->next;
    mem__Block *before;

    if (b == mem__block_list)
    {
        /* removing the head: its successor becomes the new head */
        before = NULL;
        mem__block_list = after;
    }
    else
    {
        before = b->prev;
        before->next = after;
    }

    if (after != NULL)
        after->prev = before;
}


/********************************* MEMORY DEBUGGING CODE *****************************/
#ifdef DEBUG_MEMORY

static unsigned int mem__total_size (void)
{

    /*
     * Function to return total memory usage: the sum of the "size" field
     * of every header on the list of allocated blocks (headers themselves
     * and the debug name/tag strings are not counted).
     *
     * The running total is kept unsigned, matching both the size field
     * and the return type, so no signed/unsigned conversion is involved.
     */

    mem__Block *b;
    unsigned int total = 0;

    for (b = mem__block_list; b != NULL; b = b->next)
        total += b->size;

    return total;
}



extern void mem_print_list (void)
{

    /*
     * Dump a table of every currently allocated block to the "memory"
     * debug channel: data address, requested size, allocating function
     * and tag, followed by the total of all sizes.
     */

    mem__Block *cur;

    debug_output ("memory", "Memory usage\n");
    debug_output ("memory", "------------\n");
    debug_output ("memory", "Block        size  allocated by                      tag\n");
    debug_output ("memory", "-----        ----  ------------                      ---\n");

    for (cur = mem__block_list; cur != NULL; cur = cur->next)
    {
        /* fall back on placeholders when no name/tag was recorded */
        char *who  = "<unknown fn>";
        char *what = "<unknown tag>";

        if (cur->fn)
            who = cur->fn;
        if (cur->tag)
            what = cur->tag;

        debug_output ("memory", "%08.8p %08.8d  %-33.33s %s\n",
                      (void *)(cur+1),
                      cur->size,
                      who,
                      what);
    }

    debug_output ("memory", "------------------------------\n");
    debug_output ("memory", "Total   %08.8d\n", mem__total_size());
    debug_output ("memory", "------------------------------\n");
}


extern void mem_chk (void *ptr, unsigned int flags, int line, char *file)
{

    /*
     * Sanity-check a pointer and report problems on the "memory" debug
     * channel.  Nothing is returned; the only effect is debug output.
     *
     *   ptr    - pointer to check
     *   flags  - types of memory check:
     *              CHECK_HEAP     ptr must lie within a block allocated
     *                             through this file
     *              CHECK_NONZERO  warn if ptr is zero
     *   line   - line number of the check, for the message
     *   file   - file name of the check, for the message
     *
     * A zero ptr is otherwise considered OK and ends the check.  When
     * CHECK_HEAP is requested, the heap test is the only one performed;
     * without it the pointer is passed to OS_ValidateAddress.
     */

    _kernel_swi_regs regs;
    int              carry = 0;


    if (ptr == 0)
    {
        if (flags & CHECK_NONZERO)
            debug_output ("memory", "Warning: ptr is zero: %d %s\n", line, file);

        return;  /* zero is OK */
    }

    if (flags & CHECK_HEAP)
    {
        /* check the ptr is inside one of our allocated blocks */
        mem__Block *b;

        for (b = mem__block_list; b != NULL; b = b->next)
        {
            char *start = (char *)(b+1);
            char *end   = start + b->size;

            /*
             * Both bounds must be tested: checking only against the end of
             * the block would accept any address below it.  The end
             * address itself is still accepted, as it always has been.
             */
            if ((char *)ptr >= start && (char *)ptr <= end)   /* it's pointing inside this block */
                return;
        }

        debug_output ("memory", "Warning: ptr is not in a currently allocated block %d %s\n", line, file);

        return;
    }

    /* finally do an OS_ValidateAddress on the ptr (carry set => invalid) */

    regs.r[0] = (int)ptr;
    regs.r[1] = (int)ptr;
    _kernel_swi_c (OS_ValidateAddress, &regs, &regs, &carry);
    if (carry != 0)
        debug_output ("memory", "Warning: ptr is not a valid address %d %s\n", line, file);

}




/* ------------------------------------ horrible hackery to get calling function name -------------------- */

/*
 * Always returns 0.  The result is added to a stack address inside myfp.
 * NOTE(review): presumably this exists only so that the compiler cannot
 * see the offset as a constant and optimise the address arithmetic in
 * myfp away - confirm before simplifying.
 */
static unsigned int fpoffz()
{
    return 0;
}
/*
 * Reads the word stored immediately above the local variable on the stack
 * and returns it as a pointer.
 * NOTE(review): this depends entirely on the compiler's stack frame
 * layout; presumably that word is the frame pointer saved on entry, i.e.
 * the frame pointer of whoever called myfp - confirm against the
 * procedure call standard in use.
 */
static unsigned int *myfp()
{
    unsigned int local;
    return (unsigned int *)*(&local + 1 + fpoffz());
}

/*
 * Takes the pointer returned by myfp, reads the word three words below it
 * and yields that word as a pointer.  The result is what gets passed to
 * mem__get_fn_name.
 * NOTE(review): presumably that slot holds the saved frame pointer of the
 * function which called the function using this macro - confirm against
 * the stack frame layout.
 */
#define CALLER_FP() ((unsigned int *)(*(myfp()-3)))

static char *mem__get_fn_name (unsigned int *fp)
{

    /*
     * Given a pointer to a caller's frame (if any), return the name of
     * that function, or a null pointer if none can be found.
     *
     * The word at *fp is masked with 0x03fffffc to give a code address.
     * Up to ten words before that address are then examined, nearest
     * first (the same as the backtrace code in the C library, yuk), for
     * the magic marker 0xff00XXXX, where XXXX is the length of the name
     * text; the name starts that many bytes before the marker word.  The
     * marker is only present if the code has not been compiled cc -ff.
     */

    unsigned int *scan;
    int tries;

    /* check the sanity of the address first! */
    if (fp == 0)
        return NULL;

    scan = (unsigned int *)(*fp & 0x03fffffc);

    for (tries = 10; tries > 0; tries--)
    {
        int word;

        scan--;
        word = *scan;

        if ((word & 0xffff0000) == 0xff000000)
            return (char *)scan - (word & 0xffff);
    }

    /* no marker within range */
    return NULL;
}



/* -------------------------------- freeing memory --------------------------------- */

extern void mem_free (void *block, char *msg)
{

    /*
     * Function to free a given block, and output debugging info.
     *
     *   block - pointer previously returned by mem_allocate/mem_extend.
     *           A null pointer is reported on the debug channel and
     *           otherwise ignored, rather than being treated as a block.
     *   msg   - tag describing the call site (may be 0)
     *
     * We print out who allocated the block (+tag info), and who is
     * doing the free.
     * We also give current total memory usage before and after.
     *
     * A wrong guard word is reported but the block is still unlinked
     * and freed, together with its recorded function name and tag.
     */

    char *fn_name;
    mem__Block *b;

    fn_name = mem__get_fn_name (CALLER_FP());

    /* --- there is no header in front of a null pointer --- */
    if (block == NULL)
    {
        debug_output ("memory", "Warning: attempt to free null block from %s (%s)\n",
                     (fn_name)?fn_name:"<unknown fn>",
                     (msg)?msg:"<unknown tag>");
        return;
    }

    b = (mem__Block *)block - 1;

    debug_output ("memory", "Freeing block %p from %s (%s) - [allocated from %s (%s)",
                 (void *)(b+1),
                 (fn_name)?fn_name:"<unknown fn>",
                 (msg)?msg:"<unknown tag>",
                 (b->fn)?b->fn:"<unknown fn>",
                 (b->tag)?b->tag:"<unknown tag>");
    debug_output ("memory", "(%d)] %d", b->size, mem__total_size());

    /* --- check guard word --- */
    if (b->guard != GUARD_WORD)
        debug_output ("memory", "\nInvalid guard word in memory block - block start at %p\n", block);

    /* --- remove it from list of allocated blocks --- */
    mem__remove_block(b);

    /* --- free debug info from block --- */
    if (b->fn)
    {
        free(b->fn);
    }
    if (b->tag)
    {
        free(b->tag);
    }

    /* --- ... and free the block itself --- */
    free(b);

    debug_output ("memory", " => %d\n", mem__total_size());
}


extern void mem_free_all (void)
{
    /*
     * Release every block still on the list of allocated blocks, along
     * with the function name and tag strings recorded for each one.
     */
    mem__Block *cur;
    mem__Block *following;

    for (cur = mem__block_list; cur != NULL; cur = following)
    {
        /* remember the successor before the header goes away */
        following = cur->next;

        mem__remove_block (cur);

        if (cur->fn)
            free(cur->fn);
        if (cur->tag)
            free(cur->tag);

        free(cur);
    }
}



/* -------------------------------- allocating memory ------------------------------ */



extern void *mem_allocate(unsigned int size, char *tag)
{

    /*
     * Function to allocate "size" bytes of store, with debugging info
     *
     *   size - number of bytes wanted by the caller
     *   tag  - some more info if appropriate (may be 0); it is copied
     *
     * Returns a pointer to the zero-filled store, or 0 if the request is
     * too large or any of the allocations fail.
     *
     * We store:
     *        - name of fn doing the allocating
     *        - size of requested store
     *        - extra info string
     *        - followed in memory by the "real" block
     *
     * All allocated memory blocks are stored in a linked list, so that
     * we can easily work out how much is allocated at one time, and why.
     *
     * To get fn names caller must not be compiled -ff
     * fn points at calling function, and we go back through the code area
     * to find its textual name (use 0 if none there)
     *
     */

    char *fn_name = mem__get_fn_name(CALLER_FP());
    mem__Block *new_block;
    unsigned int total = (unsigned int) sizeof(mem__Block) + size;

    debug_output ("memory", "Allocating %d bytes in %s (%s) %d",
                  size,
                  (fn_name)?fn_name:"<unknown fn>",
                  (tag)?tag:"<no tag>",
                  mem__total_size());

    /*
     * --- refuse sizes which wrap once the header is added, or which do
     *     not fit in the (signed) int taken by the underlying allocator;
     *     otherwise the block could end up smaller than its own header ---
     */
    if (total < size || total > (~0u >> 1))
    {
        debug_output ("memory", " => request too large\n");
        return 0;
    }

    /* --- allocate new block to put in list of all allocated blocks --- */
    if ((new_block = (mem__Block *)calloc(total, sizeof(char))) == 0)
    {
        debug_output ("memory", " => allocation failed\n");
        return 0;
    }
    mem__add_block (new_block);

    /* --- mark it for checking --- */
    new_block->guard = GUARD_WORD;

    /* --- note the name of the calling function --- */
    if (fn_name != NULL)
    {
        if ((new_block->fn = calloc(strlen (fn_name) + 1, sizeof(char))) == 0)
        {
            /* mem_free unlinks the block and releases what we have so far */
            mem_free ((void *)(new_block+1), "no mem for fn name");
            return 0;
        }
        strcpy (new_block->fn, fn_name);
    }

    /* --- note size of requested store --- */
    new_block->size = size;

    /* --- note tag passed by calling function --- */
    if (tag != NULL)
    {
        if ((new_block->tag = calloc(strlen (tag) + 1, sizeof(char))) == 0)
        {
            mem_free ((void *)(new_block+1), "no mem for tag");
            return 0;
        }
        strcpy (new_block->tag, tag);
    }

    /* --- return "real" block pointer to caller --- */
    debug_output ("memory", " => %d [@%p]\n", mem__total_size(), (void *)(new_block+1));

    return (void *)(new_block+1);
}





extern void *mem_extend (void *block, int by)
{

    /*
     * Function to extend an existing block of memory, with debugging info
     *
     *   block - pointer previously returned by mem_allocate/mem_extend,
     *           or NULL to allocate a fresh block of "by" bytes
     *   by    - number of bytes to change the size by
     *
     * Returns a pointer to the (possibly moved) store, or 0 on failure.
     * With a NULL block a negative "by" is rejected, since there is
     * nothing to shrink and the value would otherwise be taken as a huge
     * unsigned size.
     */

    char *fn_name = mem__get_fn_name(CALLER_FP());
    mem__Block *b;
    mem__Block *old_b;

    if (block == NULL)
    {
        if (by < 0)
            return 0;

        return mem_allocate (by, "extend block 0");
    }

    b = (mem__Block *)block - 1;
    old_b = b;

    debug_output ("memory", "Extending block %p by %d bytes in %s (=>%d)",
                  b,
                  by,
                  (fn_name)?fn_name:"<unknown fn>",
                  mem__total_size());

    /* --- extend existing block --- */
    if ((b = (mem__Block *)reallocate((void *)b, by)) == 0)
    {
        debug_output ("memory", " => failed\n");
        return 0;
    }

    /* --- check to see if block has moved, and if so, update pointers --- */
    if (old_b != b)
    {
        /* the header contents moved with the block; fix up who points at it */
        if (b->prev != NULL)
            b->prev->next = b;
        if (b->next != NULL)
            b->next->prev = b;
        if (old_b == mem__block_list)
            mem__block_list = b;
    }

    /* --- adjust size of store --- */
    b->size += by;

    /* --- return "real" block pointer to caller --- */
    debug_output ("memory", " => %d [@%p]\n", mem__total_size(), (void *)(b+1));

    return (void *)(b+1);
}

#else


/* -------------------------------- freeing memory --------------------------------- */

extern void mem_free (void *block)
{

    /*
     * Function to free a given block.
     *
     *   block - pointer previously returned by mem_allocate/mem_extend.
     *           A null pointer is ignored, since there is no header in
     *           front of it to unlink or free.
     */

    mem__Block *b;

    if (block == NULL)
        return;

    b = (mem__Block *)block - 1;

    /* --- remove it from list of allocated blocks --- */
    mem__remove_block (b);

    /* --- ... and free the block itself --- */
    free(b);
}


extern void mem_free_all (void)
{
    /*
     * Release every block still on the list of allocated blocks.
     */
    mem__Block *cur;
    mem__Block *following;

    for (cur = mem__block_list; cur != NULL; cur = following)
    {
        /* remember the successor before the header goes away */
        following = cur->next;

        mem__remove_block (cur);
        free(cur);
    }
}


/* -------------------------------- allocating memory ------------------------------ */

extern void *mem_allocate (unsigned int size)
{

    /*
     * Function to allocate "size" bytes of store.
     *
     * Returns a pointer to the zero-filled store, or 0 if the request is
     * too large or the allocation fails.
     *
     * All allocated memory blocks are stored in a linked list, so that
     * we can easily ensure it is freed on finalisation.
     *
     */

    mem__Block *new_block;
    unsigned int total = (unsigned int) sizeof(mem__Block) + size;

    /*
     * --- refuse sizes which wrap once the header is added, or which do
     *     not fit in the (signed) int taken by the underlying allocator;
     *     otherwise the block could end up smaller than its own header ---
     */
    if (total < size || total > (~0u >> 1))
        return 0;

    /* --- allocate new block to put in list of all allocated blocks --- */
    if ((new_block = (mem__Block *)calloc(total, sizeof(char))) == 0)
        return 0;
    mem__add_block (new_block);

    /* --- record size of block --- */
    new_block->size = size;

    /* --- return "real" block pointer to caller --- */
    return (void *)(new_block+1);
}

extern void *mem_extend (void *block, int by)
{

    /*
     * Function to extend an existing block of memory
     *
     *   block - pointer previously returned by mem_allocate/mem_extend,
     *           or NULL to allocate a fresh block of "by" bytes
     *   by    - number of bytes to change the size by
     *
     * Returns a pointer to the (possibly moved) store, or 0 on failure.
     * With a NULL block a negative "by" is rejected, since there is
     * nothing to shrink and the value would otherwise be taken as a huge
     * unsigned size.
     */

    mem__Block *b;
    mem__Block *old_b;

    if (block == NULL)
    {
        if (by < 0)
            return 0;

        return mem_allocate (by);
    }

    b = (mem__Block *)block - 1;
    old_b = b;

    /* --- extend existing block --- */
    if ((b = (mem__Block *)reallocate((void *)b, by)) == 0)
        return 0;

    /* --- check to see if block has moved, and if so, update pointers --- */
    if (old_b != b)
    {
        /* the header contents moved with the block; fix up who points at it */
        if (b->prev != NULL)
            b->prev->next = b;
        if (b->next != NULL)
            b->next->prev = b;
        if (old_b == mem__block_list)
            mem__block_list = b;
    }

    /* --- keep size up to date --- */
    b->size += by;

    /* --- return "real" block pointer to caller --- */
    return (void *)(b+1);
}


#endif