Merge branch 'master' of github.com:thepowersgang/acess2
[tpg/acess2.git] / Usermode / Libraries / ld-acess.so_src / loadlib.c
index a5f0bc2..5307301 100644 (file)
@@ -3,6 +3,8 @@
  By thePowersGang
 */
 #include "common.h"
+#include <stdint.h>
+#include <acess/sys.h>
 
 #define DEBUG  0
 
 #endif
 
 // === PROTOTYPES ===
-Uint   IsFileLoaded(char *file);
- int   GetSymbolFromBase(Uint base, char *name, Uint *ret);
+void   *IsFileLoaded(const char *file);
 
 // === IMPORTS ===
 extern const struct {
-       Uint    Value;
+       void    *Value;
        char    *Name;
 }      caLocalExports[];
 extern const int       ciNumLocalExports;
+extern char    **gEnvP;
+extern char    gLinkedBase[];
 
 // === GLOABLS ===
 tLoadedLib     gLoadedLibraries[MAX_LOADED_LIBRARIES];
@@ -30,7 +33,7 @@ char  *gsNextAvailString = gsLoadedStrings;
 //tLoadLib     *gpLoadedLibraries = NULL;
 
 // === CODE ===
-char *FindLibrary(char *DestBuf, char *SoName, char *ExtraSearchDir)
+const char *FindLibrary(char *DestBuf, const char *SoName, const char *ExtraSearchDir)
 {      
        // -- #1: Executable Specified
        if(ExtraSearchDir)
@@ -54,60 +57,73 @@ char *FindLibrary(char *DestBuf, char *SoName, char *ExtraSearchDir)
 
 /**
  */
-Uint LoadLibrary(char *SoName, char *SearchDir, char **envp)
+void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
 {
        char    sTmpName[1024];
-       char    *filename;
-       Uint    iArg;
-       void    (*fEntry)(int, int, char *[], char**);
+       const char      *filename;
+       void    *base;
+       void    (*fEntry)(void *, int, char *[], char**);
        
-       DEBUGS("LoadLibrary: (filename='%s', envp=0x%x)\n", filename, envp);
+       DEBUGS("LoadLibrary: (SoName='%s', SearchDir='%s', envp=%p)", SoName, SearchDir, envp);
        
        // Create Temp Name
        filename = FindLibrary(sTmpName, SoName, SearchDir);
        if(filename == NULL) {
-               DEBUGS("LoadLibrary: RETURN 0\n");
+               DEBUGS("LoadLibrary: RETURN 0");
                return 0;
        }
-       DEBUGS(" LoadLibrary: filename='%s'\n", filename);
-       
-       if( (iArg = IsFileLoaded(filename)) )
-               return iArg;
+       DEBUGS(" LoadLibrary: filename='%s'", filename);
        
+       if( (base = IsFileLoaded(filename)) )
+               return base;
+
+       DEBUGS(" LoadLibrary: SysLoadBin()");   
        // Load Library
-       iArg = SysLoadBin(filename, (Uint*)&fEntry);
-       if(iArg == 0) {
-               DEBUGS("LoadLibrary: RETURN 0\n");
+       base = _SysLoadBin(filename, (void**)&fEntry);
+       if(!base) {
+               DEBUGS("LoadLibrary: RETURN 0");
                return 0;
        }
        
-       DEBUGS(" LoadLibrary: iArg=0x%x, iEntry=0x%x\n", iArg, fEntry);
+       DEBUGS(" LoadLibrary: iArg=%p, fEntry=%p", base, fEntry);
        
        // Load Symbols
-       fEntry = (void*)DoRelocate( iArg, envp, filename );
+       fEntry = DoRelocate( base, envp, filename );
+       if( !fEntry ) {
+               return 0;
+       }
        
        // Call Entrypoint
-       DEBUGS(" LoadLibrary: '%s' Entry 0x%x\n", SoName, fEntry);
-       fEntry(iArg, 0, NULL, envp);
+       DEBUGS(" LoadLibrary: '%s' Entry %p", SoName, fEntry);
+       fEntry(base, 0, NULL, gEnvP);
        
-       DEBUGS("LoadLibrary: RETURN 1\n");
-       return iArg;
+       DEBUGS("LoadLibrary: RETURN 1");
+       return base;
 }
 
 /**
  * \fn Uint IsFileLoaded(char *file)
  * \brief Determine if a file is already loaded
  */
-Uint IsFileLoaded(char *file)
+void *IsFileLoaded(const char *file)
 {
         int    i;
        DEBUGS("IsFileLoaded: (file='%s')", file);
+
+       // Applications link against either libld-acess.so or ld-acess.so
+       if( strcmp(file, "/Acess/Libs/libld-acess.so") == 0
+        || strcmp(file, "/Acess/Libs/ld-acess.so") == 0 )
+       {
+               DEBUGS("IsFileLoaded: Found local (%p)", &gLinkedBase);
+               return &gLinkedBase;
+       }
+
        for( i = 0; i < MAX_LOADED_LIBRARIES; i++ )
        {
                if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
                DEBUGS(" strcmp('%s', '%s')", gLoadedLibraries[i].Name, file);
                if(strcmp(gLoadedLibraries[i].Name, file) == 0) {
-                       DEBUGS("IsFileLoaded: Found %i (0x%x)", i, gLoadedLibraries[i].Base);
+                       DEBUGS("IsFileLoaded: Found %i (%p)", i, gLoadedLibraries[i].Base);
                        return gLoadedLibraries[i].Base;
                }
        }
@@ -119,12 +135,12 @@ Uint IsFileLoaded(char *file)
  * \fn void AddLoaded(char *File, Uint base)
  * \brief Add a file to the loaded list
  */
-void AddLoaded(char *File, Uint base)
+void AddLoaded(const char *File, void *base)
 {
         int    i, length;
        char    *name = gsNextAvailString;
        
-       DEBUGS("AddLoaded: (File='%s', base=0x%x)", File, base);
+       DEBUGS("AddLoaded: (File='%s', base=%p)", File, base);
        
        // Find a free slot
        for( i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
@@ -148,14 +164,14 @@ void AddLoaded(char *File, Uint base)
        strcpy(name, File);
        gLoadedLibraries[i].Name = name;
        gsNextAvailString = &name[length+1];
-       DEBUGS("'%s' (0x%x) loaded as %i\n", name, base, i);
+       DEBUGS("'%s' (%p) loaded as %i", name, base, i);
        return;
 }
 
 /**
  * \fn void Unload(Uint Base)
  */
-void Unload(Uint Base)
+void Unload(void *Base)
 {      
         int    i, j;
         int    id;
@@ -167,7 +183,7 @@ void Unload(Uint Base)
        if(id == MAX_LOADED_LIBRARIES)  return;
        
        // Unload Binary
-       SysUnloadBin( Base );
+       _SysUnloadBin( Base );
        // Save String Pointer
        str = gLoadedLibraries[id].Name;
        
@@ -192,28 +208,34 @@ void Unload(Uint Base)
 }
 
 /**
- \fn Uint GetSymbol(char *name)
+ \fn Uint GetSymbol(const char *name)
  \brief Gets a symbol value from a loaded library
 */
-Uint GetSymbol(char *name)
+int GetSymbol(const char *name, void **Value, size_t *Size)
 {
         int    i;
-       Uint    ret;
        
        //SysDebug("ciNumLocalExports = %i", ciNumLocalExports);
        for(i=0;i<ciNumLocalExports;i++)
        {
-               if( strcmp(caLocalExports[i].Name, name) == 0 )
-                       return caLocalExports[i].Value;
+               if( strcmp(caLocalExports[i].Name, name) == 0 ) {
+                       *Value = caLocalExports[i].Value;
+                       if(Size)
+                               *Size = 0;
+                       return 1;
+               }
        }
        
-       for(i=0;i<sizeof(gLoadedLibraries)/sizeof(gLoadedLibraries[0]);i++)
+       // Entry 0 is ld-acess, ignore it
+       for(i = 0; i < MAX_LOADED_LIBRARIES; i ++)
        {
-               if(gLoadedLibraries[i].Base == 0)       break;
+               if(gLoadedLibraries[i].Base == 0)
+                       break;
                
                //SysDebug(" GetSymbol: Trying 0x%x, '%s'",
                //      gLoadedLibraries[i].Base, gLoadedLibraries[i].Name);
-               if(GetSymbolFromBase(gLoadedLibraries[i].Base, name, &ret))     return ret;
+               if(GetSymbolFromBase(gLoadedLibraries[i].Base, name, Value, Size))
+                       return 1;
        }
        SysDebug("GetSymbol: === Symbol '%s' not found ===", name);
        return 0;
@@ -223,13 +245,15 @@ Uint GetSymbol(char *name)
  \fn int GetSymbolFromBase(Uint base, char *name, Uint *ret)
  \breif Gets a symbol from a specified library
 */
-int GetSymbolFromBase(Uint base, char *name, Uint *ret)
+int GetSymbolFromBase(void *base, const char *name, void **ret, size_t *Size)
 {
-       if(*(Uint32*)base == (0x7F|('E'<<8)|('L'<<16)|('F'<<24)))
-               return ElfGetSymbol(base, name, ret);
-       if(*(Uint16*)base == ('M'|('Z'<<8)))
-               return PE_GetSymbol(base, name, ret);
-       SysDebug("Unknown type at %p", base);
+       uint8_t *hdr = base;
+       if(hdr[0] == 0x7F && hdr[1] == 'E' && hdr[2] == 'L' && hdr[3] == 'F')
+               return ElfGetSymbol(base, name, ret, Size);
+       if(hdr[0] == 'M' && hdr[1] == 'Z')
+               return PE_GetSymbol(base, name, ret, Size);
+       SysDebug("Unknown type at %p (%02x %02x %02x %02x)", base,
+               hdr[0], hdr[1], hdr[2], hdr[3]);
        return 0;
 }
 

UCC git Repository :: git.ucc.asn.au