Usermode/ld-acess - Fixing ELF loader modifying .text during relocation
[tpg/acess2.git] / Usermode / Libraries / ld-acess.so_src / loadlib.c
1 /*
2  AcessOS 1 - Dynamic Loader
3  By thePowersGang
4 */
5 #include "common.h"
6
#define DEBUG	0

/*
 * DEBUGS(fmt, ...) - debug trace helper.
 * Forwards its arguments to SysDebug() when DEBUG is non-zero; otherwise it
 * expands to nothing, so its arguments are not evaluated at all.
 */
#if !DEBUG
# define DEBUGS(...)
#else
# define DEBUGS(...)	SysDebug(__VA_ARGS__)
#endif
14
15 // === PROTOTYPES ===
16 void    *IsFileLoaded(char *file);
17  int    GetSymbolFromBase(void *base, const char *name, void **ret);
18
19 // === IMPORTS ===
20 extern const struct {
21         void    *Value;
22         char    *Name;
23 }       caLocalExports[];
24 extern const int        ciNumLocalExports;
25
// === GLOBALS ===
// Table of loaded libraries. Entries are filled from index 0 upward and the
// first entry whose Base is 0 terminates the list (AddLoaded takes the first
// such slot; IsFileLoaded/GetSymbol stop scanning at it).
tLoadedLib	gLoadedLibraries[MAX_LOADED_LIBRARIES];
// Backing storage for library names: NUL-terminated strings packed one
// after another, in the same order as the entries that reference them.
char	gsLoadedStrings[MAX_STRINGS_BYTES];
// First unused byte in gsLoadedStrings (where AddLoaded copies the next name).
char	*gsNextAvailString = gsLoadedStrings;
//tLoadLib	*gpLoadedLibraries = NULL;
31
32 // === CODE ===
33 char *FindLibrary(char *DestBuf, char *SoName, char *ExtraSearchDir)
34 {       
35         // -- #1: Executable Specified
36         if(ExtraSearchDir)
37         {
38                 strcpy(DestBuf, ExtraSearchDir);
39                 strcat(DestBuf, "/");
40                 strcat(DestBuf, SoName);
41                 if(file_exists(DestBuf))        return DestBuf;
42         }
43         
44         // -- #2: System
45         strcpy(DestBuf, SYSTEM_LIB_DIR);
46         strcat(DestBuf, SoName);
47         if(file_exists(DestBuf))        return DestBuf;
48         
49         // -- #3: Current Directory
50         if(file_exists(SoName)) return SoName;
51         
52         return NULL;
53 }
54
/**
 * \brief Find, load, relocate and initialise a shared library
 * \param SoName	Library file name to search for (see FindLibrary)
 * \param SearchDir	Optional extra directory searched first (may be NULL)
 * \param envp	Environment passed to relocation and to the library entry point
 * \return Load base of the library, or 0 on failure
 *
 * If the library is already loaded its existing base is returned and its
 * entry point is not run again.
 */
void *LoadLibrary(char *SoName, char *SearchDir, char **envp)
{
	char	sTmpName[1024];
	char	*filename;
	void	*base;
	void	(*fEntry)(void *, int, char *[], char**);
	
	// `filename` has not been resolved yet, so trace the requested name
	DEBUGS("LoadLibrary: (SoName='%s', envp=%p)\n", SoName, envp);
	
	// Resolve the on-disk path
	filename = FindLibrary(sTmpName, SoName, SearchDir);
	if(filename == NULL) {
		DEBUGS("LoadLibrary: RETURN 0\n");
		return 0;
	}
	DEBUGS(" LoadLibrary: filename='%s'\n", filename);
	
	// Already loaded: reuse it without re-running its entry point
	if( (base = IsFileLoaded(filename)) )
		return base;
	
	// Load Library
	base = SysLoadBin(filename, (void**)&fEntry);
	if(!base) {
		DEBUGS("LoadLibrary: RETURN 0\n");
		return 0;
	}
	
	DEBUGS(" LoadLibrary: iArg=%p, iEntry=%p\n", base, fEntry);
	
	// Load Symbols; the value returned by DoRelocate is used as the entry point
	fEntry = DoRelocate( base, envp, filename );
	if( !fEntry ) {
		// Never call through a null entry pointer.
		// NOTE(review): the image mapped by SysLoadBin is not released here,
		// since it cannot be seen from this file whether DoRelocate has
		// already registered it with AddLoaded - confirm and unload if safe.
		SysDebug("LoadLibrary: Relocation of '%s' failed", filename);
		return 0;
	}
	
	// Call Entrypoint
	DEBUGS(" LoadLibrary: '%s' Entry %p\n", SoName, fEntry);
	fEntry(base, 0, NULL, envp);
	
	DEBUGS("LoadLibrary: RETURN 1\n");
	return base;
}
96
97 /**
98  * \fn Uint IsFileLoaded(char *file)
99  * \brief Determine if a file is already loaded
100  */
101 void *IsFileLoaded(char *file)
102 {
103          int    i;
104         DEBUGS("IsFileLoaded: (file='%s')", file);
105         for( i = 0; i < MAX_LOADED_LIBRARIES; i++ )
106         {
107                 if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
108                 DEBUGS(" strcmp('%s', '%s')", gLoadedLibraries[i].Name, file);
109                 if(strcmp(gLoadedLibraries[i].Name, file) == 0) {
110                         DEBUGS("IsFileLoaded: Found %i (0x%x)", i, gLoadedLibraries[i].Base);
111                         return gLoadedLibraries[i].Base;
112                 }
113         }
114         DEBUGS("IsFileLoaded: Not Found");
115         return 0;
116 }
117
118 /**
119  * \fn void AddLoaded(char *File, Uint base)
120  * \brief Add a file to the loaded list
121  */
122 void AddLoaded(char *File, void *base)
123 {
124          int    i, length;
125         char    *name = gsNextAvailString;
126         
127         DEBUGS("AddLoaded: (File='%s', base=0x%x)", File, base);
128         
129         // Find a free slot
130         for( i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
131         {
132                 if(gLoadedLibraries[i].Base == 0)       break;
133         }
134         if(i == MAX_LOADED_LIBRARIES) {
135                 SysDebug("ERROR - ld-acess.so has run out of load slots!");
136                 return;
137         }
138         
139         // Check space in string buffer
140         length = strlen(File);
141         if(&name[length+1] >= &gsLoadedStrings[MAX_STRINGS_BYTES]) {
142                 SysDebug("ERROR - ld-acess.so has run out of string buffer memory!");
143                 return;
144         }
145         
146         // Set information
147         gLoadedLibraries[i].Base = base;
148         strcpy(name, File);
149         gLoadedLibraries[i].Name = name;
150         gsNextAvailString = &name[length+1];
151         DEBUGS("'%s' (0x%x) loaded as %i\n", name, base, i);
152         return;
153 }
154
155 /**
156  * \fn void Unload(Uint Base)
157  */
158 void Unload(void *Base)
159 {       
160          int    i, j;
161          int    id;
162         char    *str;
163         for( id = 0; id < MAX_LOADED_LIBRARIES; id++ )
164         {
165                 if(gLoadedLibraries[id].Base == Base)   break;
166         }
167         if(id == MAX_LOADED_LIBRARIES)  return;
168         
169         // Unload Binary
170         SysUnloadBin( Base );
171         // Save String Pointer
172         str = gLoadedLibraries[id].Name;
173         
174         // Compact Loaded List
175         j = id;
176         for( i = j + 1; i < MAX_LOADED_LIBRARIES; i++, j++ )
177         {
178                 if(gLoadedLibraries[i].Base == 0)       break;
179                 // Compact String
180                 strcpy(str, gLoadedLibraries[i].Name);
181                 str += strlen(str)+1;
182                 // Compact Entry
183                 gLoadedLibraries[j].Base = gLoadedLibraries[i].Base;
184                 gLoadedLibraries[j].Name = str;
185         }
186         
187         // NULL Last Entry
188         gLoadedLibraries[j].Base = 0;
189         gLoadedLibraries[j].Name = NULL;
190         // Save next string
191         gsNextAvailString = str;
192 }
193
194 /**
195  \fn Uint GetSymbol(const char *name)
196  \brief Gets a symbol value from a loaded library
197 */
198 void *GetSymbol(const char *name)
199 {
200          int    i;
201         void    *ret;
202         
203         //SysDebug("ciNumLocalExports = %i", ciNumLocalExports);
204         for(i=0;i<ciNumLocalExports;i++)
205         {
206                 if( strcmp(caLocalExports[i].Name, name) == 0 )
207                         return caLocalExports[i].Value;
208         }
209         
210         // Entry 0 is ld-acess, ignore it
211         for(i = 1; i < MAX_LOADED_LIBRARIES; i ++)
212         {
213                 if(gLoadedLibraries[i].Base == 0)       break;
214                 
215                 //SysDebug(" GetSymbol: Trying 0x%x, '%s'",
216                 //      gLoadedLibraries[i].Base, gLoadedLibraries[i].Name);
217                 if(GetSymbolFromBase(gLoadedLibraries[i].Base, name, &ret))     return ret;
218         }
219         SysDebug("GetSymbol: === Symbol '%s' not found ===", name);
220         return 0;
221 }
222
/**
 * \fn int GetSymbolFromBase(void *base, const char *name, void **ret)
 * \brief Gets a symbol from a specified library
 * \param base	Load base of the library (start of its image header)
 * \param name	Symbol name
 * \param ret	Receives the symbol value when the lookup succeeds
 * \return Result of ElfGetSymbol/PE_GetSymbol (non-zero is treated as found
 *         by GetSymbol), or 0 if the image format is not recognised
 */
int GetSymbolFromBase(void *base, const char *name, void **ret)
{
	const unsigned char	*hdr = base;
	
	// ELF magic "\x7F" 'E' 'L' 'F'. Compare byte by byte so the check does not
	// depend on host endianness or on the alignment of `base`.
	if( hdr[0] == 0x7F && hdr[1] == 'E' && hdr[2] == 'L' && hdr[3] == 'F' )
		return ElfGetSymbol(base, name, ret);
	// PE images start with the DOS "MZ" signature
	if( hdr[0] == 'M' && hdr[1] == 'Z' )
		return PE_GetSymbol(base, name, ret);
	SysDebug("Unknown type at %p", base);
	return 0;
}
236

UCC git Repository :: git.ucc.asn.au